In [1]:
from IPython.display import Audio
from datasets import load_dataset

### Prepare dataset

In [2]:
# Cantonese ("yue") configuration of Common Voice 13.
# NOTE(review): trust_remote_code=True allows the dataset repository's loading
# script to execute locally — keep it only while that repository is trusted.
cv13 = load_dataset(
    path="mozilla-foundation/common_voice_13_0",
    name="yue",
    trust_remote_code=True,
)

# UrbanSound8K, used further down as the source of noise clips.
urban = load_dataset(path="danavery/urbansound8K")

In [3]:
from MangoDemo.mangoPURE.data.mixer import DatasetMixer
from MangoDemo.mangoPURE.data.transforms import *
from MangoDemo.mangoPURE.data.providers import *

In [4]:
# Wrap the Common Voice "test" split and the UrbanSound8K "train" split in
# provider objects (judging by the class names they presumably hand out random
# clips — TODO confirm against MangoDemo.mangoPURE.data.providers).
# NOTE(review): speaker_provider is not referenced by any later cell visible in
# this notebook; only noise_provider is passed to a transform — confirm whether
# the speaker provider is still needed here.
speaker_provider = CV13Random(cv13["test"])
noise_provider = UrbanRandom(urban["train"])

In [5]:
# Transform steps given to the mixer, kept in this exact order.
# Judging by the class names: start from a blank clip, add noise segments taken
# from noise_provider, then merge everything — TODO confirm in data.transforms.
blank_audio_step = CreateRandomBlankAudio()
noise_segments_step = AddSeveralRandomNoiseSegments(noise_provider)
merge_step = MergeAll()
trans = [blank_audio_step, noise_segments_step, merge_step]

mixer = DatasetMixer(trans)

### Test collator

In [6]:
from MangoDemo.mangoPURE.models.collators import WhisperToTimedBatch, WhisperToTimedBatchConfig
from transformers import WhisperFeatureExtractor

In [7]:
# Generate a handful of samples from the mixer to feed the collator.
# NOTE(review): no seed is set in the visible cells, so the samples (and the
# loss value shown at the end) will presumably differ between runs — confirm.
sample_count = 3
batch_list = [mixer.generate() for _sample_idx in range(sample_count)]

The collator needs two facts about the model:
- the number of output timestamps (1500)
- the number of classes to predict (11 in this notebook)

In [8]:
# Feature extractor matching the whisper-tiny checkpoint.
whisper_checkpoint = "openai/whisper-tiny"
feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_checkpoint)

# Collator settings: build labels for the "noise" task with 11 classes
# (10 UrbanSound8K classes + 1 blank) over 1500 output timestamps.
collator_options = {
    "create_labels": True,
    "noise_classes": 11,
    "output_timestamps": 1500,
    "diar_type": "noise",
}
config = WhisperToTimedBatchConfig(**collator_options)
collator = WhisperToTimedBatch(feature_extractor, config)

In [9]:
# Collate the list of generated samples into a single batch object.
batch_timed = collator(batch_list)

In [10]:
# Inspect which fields the collated batch exposes.
batch_timed.keys()

dict_keys(['input_features', 'attention_mask', 'labels'])

In [11]:
# Check the label tensor shape; presumably (batch, output_timestamps,
# noise_classes) given the collator config — TODO confirm the axis order.
batch_timed["labels"].shape

torch.Size([3, 1500, 11])

### Test whisper timed model & timed loss
The loss is computed inside the model wrapper, so calling the wrapper returns it directly.

In [12]:
from MangoDemo.mangoPURE.models.modules import WhisperEmbedder, LinearTimedHead, WhisperTimedModel
from MangoDemo.mangoPURE.models.metrics import SigmoidTimedLoss

Two facts we need to know beforehand:
- the whisper-tiny output embedding dimension for each timestamp is 384
- UrbanSound8K has 10 classes (11 = 10 + 1, because of the blank class)

In [13]:
# Dimensions of the linear head.
whisper_tiny_embed_dim = 384  # per-timestamp embedding size of whisper-tiny
num_output_classes = 11       # 10 UrbanSound8K classes + 1 blank class

embedder = WhisperEmbedder("openai/whisper-tiny")
head = LinearTimedHead(whisper_tiny_embed_dim, num_output_classes)
loss_fn = SigmoidTimedLoss()

In [14]:
# Bundle the embedder, the prediction head and the loss into one model object.
wrapper = WhisperTimedModel(embedder=embedder, head=head, loss_fn=loss_fn)

In [15]:
# Forward pass on the collated batch; the result is indexed by "loss" afterwards.
output = wrapper(batch_timed) 

In [16]:
# Display the loss value returned by the forward pass.
output["loss"]

tensor(0.7793, grad_fn=<DivBackward0>)