In [1]:
import os
import numpy as np
import pandas as pd

import torch
import torchaudio
import torch.nn.functional as F
from torch import nn
import whisper
from whisper.audio import (
    log_mel_spectrogram,
    pad_or_trim,
)
import jiwer
from tqdm import tqdm
from main import *

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = whisper.load_model("tiny.en")
# Report the total number of parameter elements across all tensors of the model.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model has {n_params:,} parameters.")

Model has 37,184,256 parameters.


In [3]:
def collect_params(model):
    """Freeze ``model`` and unfreeze only the encoder parameters adapted at test time.

    Every parameter of ``model`` is first set to ``requires_grad=False``. Then,
    walking ``model.named_modules()`` in order, these are re-enabled and collected:

    * the ``weight`` / ``bias`` of every ``nn.LayerNorm`` whose top-level module
      name is ``encoder``;
    * all parameters of ``encoder.conv1`` and ``encoder.conv2`` (and of any
      module nested under those names).

    Each parameter is collected at most once, even if it is reachable from
    several matching modules. This matters because an optimizer param group
    must not contain duplicates.

    Args:
        model: an ``nn.Module`` expected to expose an ``encoder`` submodule.

    Returns:
        ``(params, names)``: the unfrozen parameters in ``named_modules`` order,
        and their dotted names (``"<module>.<param>"``).
    """
    trainable = ("weight", "bias")       # LayerNorm params to adapt
    feature_convs = ("conv1", "conv2")   # encoder front-end convolutions
    params = []
    names = []
    seen = set()

    def _add(qualified_name, p):
        # Unfreeze and record a parameter unless it was already collected.
        if id(p) in seen:
            return
        seen.add(id(p))
        p.requires_grad = True
        params.append(p)
        names.append(qualified_name)

    for param in model.parameters():
        param.requires_grad = False

    for module_name, module in model.named_modules():
        parts = module_name.split(".")
        if parts[0] != "encoder":
            continue
        # Encoder LayerNorm affine parameters.
        if isinstance(module, nn.LayerNorm):
            for pname, p in module.named_parameters():
                if pname in trainable:
                    _add(f"{module_name}.{pname}", p)
        # Convolutional feature extractor at the front of the encoder.
        if len(parts) > 1 and parts[1] in feature_convs:
            for pname, p in module.named_parameters():
                _add(f"{module_name}.{pname}", p)

    return params, names

In [4]:
def forward_and_adapt(x, model, optimizer, em_coef=1.0, reweight=False, temp=1., not_blank=True, scheduler=None, 
                        div_coef=0, repeat_inference=True, skip_short_thd=None, decode_options=None):
    """Run one test-time-adaptation step on ``x`` and optionally re-decode it.

    The function proceeds in four stages:

    1. Decodes ``x`` with autograd enabled and stacks the per-step logits.
    2. Builds an unsupervised loss from those logits, with each term computed
       only when its coefficient is positive:

       * ``em_coef * softmax_entropy(logits / temp).mean(0).mean()``
       * ``(1 - em_coef) * mcc_loss(logits / temp, reweight)``
       * ``div_coef * div_loss(logits, not_blank)``

    3. Back-propagates the loss, steps the optimizer (and scheduler, if given),
       then clears the model's gradients.
    4. Optionally decodes ``x`` again, without gradients.

    Args:
        x: input passed straight to ``model.decode`` (a batch of log-Mel
            spectrograms in this notebook).
        model: model whose ``decode`` returns an indexable result whose
            element ``[1]`` is a sequence of per-step logits tensors.
            NOTE(review): stock Whisper ``decode`` does not return this;
            this relies on a patched ``decode`` — confirm.
        optimizer: optimizer over the parameters being adapted.
        em_coef: weight of the entropy term; ``1 - em_coef`` weights the MCC term.
        reweight: forwarded to ``mcc_loss``.
        temp: temperature dividing the logits for the entropy and MCC terms.
        not_blank: forwarded to ``div_loss``. NOTE(review): an earlier
            docstring said "<pad> is index 0 in the vocab", which looks
            carried over from CTC code; confirm it holds for Whisper's tokenizer.
        scheduler: optional LR scheduler stepped after the optimizer.
        div_coef: weight of the diversity term.
        repeat_inference: if True, decode ``x`` again under ``torch.no_grad``
            after the update and return that; otherwise return the
            pre-update decoding.
        skip_short_thd: unused; kept for interface compatibility.
        decode_options: decoding options for both ``model.decode`` calls.
            Defaults to the notebook-level ``options`` global.

    Returns:
        The result of the last ``model.decode`` call.
    """
    if decode_options is None:
        decode_options = options  # notebook-level global (hidden dependency)

    # Forward with autograd on so the loss can reach the trainable parameters.
    outputs = model.decode(x, decode_options)
    # Stack per-step logits on a new leading axis, then swap the first two axes.
    # The original note recorded torch.Size([1, 5, 51864]) for one utterance.
    logits = torch.stack(outputs[1], dim=0).permute(1, 0, 2)

    loss = 0
    if em_coef > 0:
        e_loss = softmax_entropy(logits / temp).mean(0).mean()
        loss += e_loss * em_coef
    if 1 - em_coef > 0:
        c_loss = mcc_loss(logits / temp, reweight)
        loss += c_loss * (1 - em_coef)
    if div_coef > 0:
        d_loss = div_loss(logits, not_blank)
        loss += d_loss * div_coef

    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    model.zero_grad()

    # Re-decode with the updated weights; no graph is needed for this pass.
    if repeat_inference:
        with torch.no_grad():
            outputs = model.decode(x, decode_options)
    return outputs

In [5]:
LEARNING_RATE = 3e-4

# Place the model on its device first, then collect parameters from it.
# This way the optimizer is built over exactly the tensors that will be updated.
model = model.to(DEVICE)
params, names = collect_params(model)
options = whisper.DecodingOptions(language="en", without_timestamps=True)
optimizer, scheduler = setup_optimizer(params, 'AdamW', lr=LEARNING_RATE, scheduler=None)

[INFO]    optimizer: <class 'torch.optim.adamw.AdamW'>
[INFO]    scheduler: None


In [6]:
from data import load_dataset

# LibriSpeech "test-other", one utterance per batch; extra_noise=0.01 is passed
# through to the project's loader.
dataset = load_dataset(
    name='librispeech',
    split=['test-other'],
    path='../LibriSpeech',
    batch_size=1,
    extra_noise=0.01,
)

Read text:   0%|          | 0/2939 [00:00<?, ?it/s]

Read text: 100%|██████████| 2939/2939 [00:00<00:00, 70249.66it/s]

[INFO]    There are 2939 samples.





In [7]:
# Test-time adaptation over `dataset`. Before every utterance, the model,
# optimizer and scheduler are restored from the snapshot taken below; the
# helpers come from `main`, presumably performing a full state reset (confirm).
# The model is then adapted for ADAPT_STEPS steps on that utterance alone.
ADAPT_STEPS = 10

transcriptions = []           # hypotheses after adaptation
ori_transcriptions = []       # reference transcripts
baseline_transcriptions = []  # hypotheses from the restored, unadapted model
model_state, optimizer_state, scheduler_state = copy_model_and_optimizer(model, optimizer, scheduler)
for batch in tqdm(dataset):
    lens, wavs, texts, files = batch
    audio = pad_or_trim(wavs[0])
    # (n_mels, frames) -> (1, n_mels, frames); equivalent to unsqueeze(-1).permute(2, 0, 1).
    mel = log_mel_spectrogram(audio).unsqueeze(0).to(DEVICE)

    # Reset BEFORE the baseline decode. Otherwise the baseline would use the
    # weights adapted on the previous utterance.
    model, optimizer, scheduler = load_model_and_optimizer(
        model, optimizer, scheduler, model_state, optimizer_state, scheduler_state
    )
    # The baseline needs no autograd graph. Keeping one alive across all
    # adaptation steps only wastes GPU memory (this cell ran out of memory before).
    with torch.no_grad():
        baseline_output = model.decode(mel, options)
    baseline_transcriptions.append(baseline_output[0][0].text)

    for step in range(ADAPT_STEPS):
        # Only the last step's post-update re-decode is kept, so skip it earlier.
        adapt_output = forward_and_adapt(
            mel, model, optimizer, repeat_inference=(step == ADAPT_STEPS - 1)
        )
    transcriptions.append(adapt_output[0][0].text)
    ori_transcriptions.append(texts[0])
    del baseline_output, adapt_output
    torch.cuda.empty_cache()

  0%|          | 0/2939 [00:00<?, ?it/s]

  0%|          | 0/2939 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 38.00 MiB. GPU 0 has a total capacty of 11.72 GiB of which 48.69 MiB is free. Including non-PyTorch memory, this process has 11.29 GiB memory in use. Of the allocated memory 11.01 GiB is allocated by PyTorch, and 84.87 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [8]:
# Test-time adaptation on the first MAX_UTTERANCES batches of `loader`.
# Before every batch, the model, optimizer and scheduler are restored from the
# snapshot taken below.
#
# NOTE(review): `loader` is not defined anywhere in this notebook. It is
# leftover kernel state, so this cell fails on a fresh kernel. Its progress bar
# showed 2620 batches, while `dataset` has 2939 samples (test-other); 2620 is
# the size of LibriSpeech test-clean. Confirm which split the WER below came from.
if "loader" not in globals():
    raise NameError(
        "`loader` is not defined in this notebook; recreate the DataLoader "
        "that yields (mels, texts) batches before running this cell."
    )

ADAPT_STEPS = 10
MAX_UTTERANCES = 130

transcriptions = []           # hypotheses after adaptation
ori_transcriptions = []       # reference transcripts
baseline_transcriptions = []  # hypotheses from the restored, unadapted model
model_state, optimizer_state, scheduler_state = copy_model_and_optimizer(model, optimizer, scheduler)
for sample_idx, (mels, texts) in enumerate(tqdm(loader)):
    if sample_idx >= MAX_UTTERANCES:
        break
    mels = mels.to(DEVICE)

    # Reset BEFORE the baseline decode so it reflects the original weights.
    model, optimizer, scheduler = load_model_and_optimizer(
        model, optimizer, scheduler, model_state, optimizer_state, scheduler_state
    )
    # No autograd graph needed for the baseline; keeping one alive wastes GPU memory.
    with torch.no_grad():
        baseline_output = model.decode(mels, options)
    baseline_transcriptions.append(baseline_output[0][0].text)

    for step in range(ADAPT_STEPS):
        # Only the final step's post-update re-decode is used.
        adapt_output = forward_and_adapt(
            mels, model, optimizer, repeat_inference=(step == ADAPT_STEPS - 1)
        )
    transcriptions.append(adapt_output[0][0].text)
    ori_transcriptions.append(texts[0])
    del baseline_output, adapt_output
    torch.cuda.empty_cache()

  5%|▍         | 130/2620 [04:23<1:24:03,  2.03s/it]


In [9]:
# One row per utterance: the adapted hypothesis next to its reference transcript.
data = pd.DataFrame({"hypothesis": transcriptions, "reference": ori_transcriptions})
data

Unnamed: 0,hypothesis,reference
0,"He hoped there would be stew for dinner, turni...",HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...
1,Stuff it into you. His belly counseled him.,STUFF IT INTO YOU HIS BELLY COUNSELLED HIM
2,"After early nightfall, the yellow lamps would ...",AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...
3,"Hello, Bertie. Any good in your mind?",HELLO BERTIE ANY GOOD IN YOUR MIND
4,"Number 10, fresh Nelly is waiting on you. Good...",NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...
...,...,...
125,I'm me.,AY ME
126,"No matter then, although my foot did stand, up...",NO MATTER THEN ALTHOUGH MY FOOT DID STAND UPON...
127,"Thought kills me that I am not thought, to lea...",THOUGHT KILLS ME THAT I AM NOT THOUGHT TO LEAP...
128,My heart-doth plead that thou in him-doth lie....,MY HEART DOTH PLEAD THAT THOU IN HIM DOST LIE ...


In [10]:
# `jiwer` is already imported in the first cell; only the normalizer is new here.
from whisper.normalizers import EnglishTextNormalizer

# Whisper's English text normalizer, used to clean transcripts before scoring.
normalizer = EnglishTextNormalizer()

In [11]:
# Apply the same normalizer to hypotheses and references, so both columns are
# compared in one canonical form.
data["hypothesis_clean"] = data["hypothesis"].map(normalizer)
data["reference_clean"] = data["reference"].map(normalizer)
data

Unnamed: 0,hypothesis,reference,hypothesis_clean,reference_clean
0,"He hoped there would be stew for dinner, turni...",HE HOPED THERE WOULD BE STEW FOR DINNER TURNIP...,he hoped there would be stew for dinner turnip...,he hoped there would be stew for dinner turnip...
1,Stuff it into you. His belly counseled him.,STUFF IT INTO YOU HIS BELLY COUNSELLED HIM,stuff it into you his belly counseled him,stuff it into you his belly counseled him
2,"After early nightfall, the yellow lamps would ...",AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD L...,after early nightfall the yellow lamps would l...,after early nightfall the yellow lamps would l...
3,"Hello, Bertie. Any good in your mind?",HELLO BERTIE ANY GOOD IN YOUR MIND,hello bertie any good in your mind,hello bertie any good in your mind
4,"Number 10, fresh Nelly is waiting on you. Good...",NUMBER TEN FRESH NELLY IS WAITING ON YOU GOOD ...,number 10 fresh nelly is waiting on you good n...,number 10 fresh nelly is waiting on you good n...
...,...,...,...,...
125,I'm me.,AY ME,i am me,ay me
126,"No matter then, although my foot did stand, up...",NO MATTER THEN ALTHOUGH MY FOOT DID STAND UPON...,no matter then although my foot did stand upon...,no matter then although my foot did stand upon...
127,"Thought kills me that I am not thought, to lea...",THOUGHT KILLS ME THAT I AM NOT THOUGHT TO LEAP...,thought kills me that i am not thought to leap...,thought kills me that i am not thought to leap...
128,My heart-doth plead that thou in him-doth lie....,MY HEART DOTH PLEAD THAT THOU IN HIM DOST LIE ...,my heart doth plead that thou in him doth lie ...,my heart doth plead that thou in him dost lie ...


In [13]:
# Persist the results table (the index is written too, per pandas' default).
data.to_csv(path_or_buf='./test.csv')

In [12]:
# Corpus-level WER over the normalized columns (reference first, then hypothesis).
references = data["reference_clean"].tolist()
hypotheses = data["hypothesis_clean"].tolist()
wer = jiwer.wer(references, hypotheses)
print(f"WER: {100 * wer:.2f} %")

WER: 4.72 %


In [33]:
# NOTE(review): this duplicates an earlier WER cell verbatim. Its recorded output
# differs, so it was executed (out of order) against a different `data`; that
# run's settings are not recorded in the notebook.
references = data["reference_clean"].tolist()
hypotheses = data["hypothesis_clean"].tolist()
wer = jiwer.wer(references, hypotheses)
print(f"WER: {100 * wer:.2f} %")

WER: 4.27 %


In [41]:
# NOTE(review): another verbatim copy of the WER cell. Its different recorded
# output means `data` changed between executions; the configuration behind it
# is not captured anywhere in the notebook.
references = data["reference_clean"].tolist()
hypotheses = data["hypothesis_clean"].tolist()
wer = jiwer.wer(references, hypotheses)
print(f"WER: {100 * wer:.2f} %")

WER: 3.05 %
