In [1]:
from mangoPURE.data import DatasetMixer
from mangoPURE.data.mixer import TorchDatasetWrapper
from mangoPURE.data.transforms import CreateRandomBlankAudio, AddRandomFilledNoise, MergeAll
from mangoPURE.data.providers import UrbanRandom

In [2]:
from datasets import load_dataset

urban_provider = UrbanRandom(load_dataset('danavery/urbansound8K', split='train'))
data = DatasetMixer([CreateRandomBlankAudio(), AddRandomFilledNoise(urban_provider), MergeAll()])

In [3]:
dataset = TorchDatasetWrapper(data, 1500)

In [4]:
from torch.utils.data import DataLoader
from mangoPURE.models.collators import OneNoiseCollator
from transformers import WhisperFeatureExtractor

extractor = WhisperFeatureExtractor.from_pretrained('openai/whisper-tiny')
loader = DataLoader(dataset, 4, collate_fn=OneNoiseCollator(extractor))

In [5]:
batch = next(iter(loader))

In [6]:
from mangoPURE.models.modules import WhisperTimedModel

In [7]:
from mangoPURE.models.modules import WhisperEmbedder, LinearSoloHead
from mangoPURE.models.metrics import CrossEntropyLoss

model = WhisperTimedModel(WhisperEmbedder('openai/whisper-tiny'), LinearSoloHead(384, 11), CrossEntropyLoss())

In [9]:
model(**batch)

{'input_features': tensor([[[ 0.5568,  0.5941,  0.4181,  ..., -1.2755, -1.2755, -1.2755],
          [ 0.6073,  0.5266,  0.3078,  ..., -1.2755, -1.2755, -1.2755],
          [ 0.4225,  0.3603,  0.2589,  ..., -1.2755, -1.2755, -1.2755],
          ...,
          [-0.5719, -0.8339, -0.7949,  ..., -1.2755, -1.2755, -1.2755],
          [-0.6031, -0.8243, -0.8311,  ..., -1.2755, -1.2755, -1.2755],
          [-0.5545, -0.9321, -0.9883,  ..., -1.2755, -1.2755, -1.2755]],
 
         [[ 0.7026,  0.7514,  0.7443,  ..., -0.6381, -0.6381, -0.6381],
          [ 0.6897,  0.7349,  0.7671,  ..., -0.6381, -0.6381, -0.6381],
          [ 0.5224,  0.6129,  0.7306,  ..., -0.6381, -0.6381, -0.6381],
          ...,
          [-0.1853, -0.2337, -0.1636,  ..., -0.6381, -0.6381, -0.6381],
          [-0.2660, -0.3524, -0.2340,  ..., -0.6381, -0.6381, -0.6381],
          [-0.4477, -0.4827, -0.3993,  ..., -0.6381, -0.6381, -0.6381]],
 
         [[ 0.7889,  0.8189,  0.7841,  ..., -0.9165, -0.9165, -0.9165],
          

In [9]:
from mango.training.MangoTrainer import MangoTrainer
from mango.training.MangoTrainer import TrainerConfig
from mango.training.trackers import NeptuneTracker

In [10]:
config = TrainerConfig('whisper-solo-clf', logs_frequency_batches=8, save_strategy='epoch', early_stopping_patience=3)

tracker = NeptuneTracker('mango/mango-noise',
                         'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJhODNmZWQ5NS1hNDRiLTRiZDUtODJhYS1jYmRiZTQ0MDkzNDQifQ==',
                         'whisper-solo-clf', tags=[
        'solo-noise',
        'whisper-tiny',
        'urban-sound',
    ])

trainer = MangoTrainer(
    model,
    loader,
    loader,
    config,
    trackers=[tracker]
)



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/mango/mango-noise/e/MNGNS-20


In [11]:
trainer.train(1, None)

train:   0%|          | 0/375 [00:00<?, ?it/s]

eval:   0%|          | 0/375 [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 11.51 GB, other allocations: 7.53 GB, max allowed: 18.13 GB). Tried to allocate 64.50 KB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).