In [1]:
from IPython.display import Audio
from datasets import load_dataset

### Prepare dataset

In [2]:
# Speech corpus: the 'yue' configuration of Common Voice 13.
# NOTE(review): trust_remote_code=True lets the dataset repository's loading
# script run on this machine -- confirm the source is trusted.
cv13 = load_dataset(
    "mozilla-foundation/common_voice_13_0",
    "yue",
    trust_remote_code=True,
)

# Noise corpus: UrbanSound8K.
urban = load_dataset("danavery/urbansound8K")

In [3]:
from MangoDemo.mangoPURE.data.mixer import DatasetMixer
from MangoDemo.mangoPURE.data.transforms import *
from MangoDemo.mangoPURE.data.providers import *

In [4]:
# Speech samples are drawn from the Common Voice "test" split.
# NOTE(review): speaker_provider is not referenced by the transform pipeline
# in the cells visible here -- confirm whether it is still needed.
cv13_test_split = cv13["test"]
speaker_provider = CV13Random(cv13_test_split)

# Noise samples are drawn from the UrbanSound8K "train" split.
urban_train_split = urban["train"]
noise_provider = UrbanRandom(urban_train_split)

In [5]:
# Transform pipeline handed to the mixer, in this order:
# a random blank audio, several random noise segments taken from
# noise_provider, and a final merge step.
# (presumably applied sequentially by DatasetMixer -- TODO confirm)
trans = [
    CreateRandomBlankAudio(),
    AddSeveralRandomNoiseSegments(noise_provider),
    MergeAll(),
]

mixer = DatasetMixer(trans)

### Test collator

In [6]:
from MangoDemo.mangoPURE.models.collators import WhisperToTimedBatch
from transformers import WhisperFeatureExtractor

In [7]:
# Generate a small list of mixed samples to feed the collator.
NUM_TEST_SAMPLES = 3
batch_list = [mixer.generate() for _ in range(NUM_TEST_SAMPLES)]

In [8]:
# Feature extractor loaded from the whisper-tiny checkpoint; the collator
# is built on top of it.
whisper_checkpoint = "openai/whisper-tiny"
feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_checkpoint)
collator = WhisperToTimedBatch(feature_extractor)

In [9]:
# Collate the generated samples into one batch (displayed below as a
# TimedAudioBatch) and show it as the cell output.
batch_timed = collator(batch_list)
batch_timed

TimedAudioBatch(audio=tensor([[[ 0.7794,  0.4041,  0.0346,  ..., -0.7226, -0.7226, -0.7226],
         [ 0.7595,  0.6332,  0.4977,  ..., -0.7226, -0.7226, -0.7226],
         [ 0.5897,  0.6426,  0.7625,  ..., -0.7226, -0.7226, -0.7226],
         ...,
         [-0.2645, -0.2896, -0.2377,  ..., -0.7226, -0.7226, -0.7226],
         [-0.4441, -0.1845, -0.2053,  ..., -0.7226, -0.7226, -0.7226],
         [-0.4526, -0.2979, -0.3516,  ..., -0.7226, -0.7226, -0.7226]],

        [[-0.8877, -0.8877, -0.8877,  ..., -0.8877, -0.8877, -0.8877],
         [-0.8877, -0.8877, -0.8877,  ..., -0.8877, -0.8877, -0.8877],
         [-0.8877, -0.8877, -0.8877,  ..., -0.8877, -0.8877, -0.8877],
         ...,
         [-0.8877, -0.8877, -0.8877,  ..., -0.8877, -0.8877, -0.8877],
         [-0.8877, -0.8877, -0.8877,  ..., -0.8877, -0.8877, -0.8877],
         [-0.8877, -0.8877, -0.8877,  ..., -0.8877, -0.8877, -0.8877]],

        [[-0.9167, -0.9167, -0.9167,  ..., -0.9167, -0.9167, -0.9167],
         [-0.9167, -0.9

### Test whisper timed model & timed loss
The loss is computed inside the wrapper, so it is read from the wrapper's output (`output.loss`).

In [10]:
from MangoDemo.mangoPURE.models.modules import WhisperEmbedder, LinearTimedHead
from MangoDemo.mangoPURE.models.wrappers import WhisperTimedWrapper
from MangoDemo.mangoPURE.models.metrics import SigmoidTimedLoss

Two facts to keep in mind here:
- the whisper-tiny output embedding dimension for each time step is 384
- UrbanSound8K has 10 classes, so the head outputs 11 (= 10 + 1, because of the extra blank class)

In [11]:
# Embedding width produced by whisper-tiny for each time step.
WHISPER_TINY_EMBED_DIM = 384
# 10 UrbanSound8K classes plus one extra blank class.
NUM_TIMED_CLASSES = 10 + 1

embedder = WhisperEmbedder("openai/whisper-tiny")
head = LinearTimedHead(WHISPER_TINY_EMBED_DIM, NUM_TIMED_CLASSES)
loss_fn = SigmoidTimedLoss()

In [12]:
# Bundle the embedder, the classification head and the loss into one module;
# the loss function is passed in here, so the wrapper's output carries the loss.
wrapper = WhisperTimedWrapper(embedder=embedder, head=head, loss_fn=loss_fn)

In [13]:
# Forward pass over the collated batch; the result exposes a `.loss` attribute.
output = wrapper(batch_timed)

In [14]:
# Show the loss returned by the wrapper for this batch; the recorded output
# carries a grad_fn, i.e. it is attached to the autograd graph.
output.loss

tensor(0.7726, grad_fn=<DivBackward0>)