In [1]:
import os

import accelerate
import optuna

In [2]:
from mangoPURE.data.augment import AugmentationTransform

In [3]:
from mangoPURE.data import DatasetMixer
from mangoPURE.data.mixer import TorchDatasetWrapper
from mangoPURE.data.transforms import CreateRandomBlankAudio, AddRandomFilledNoise, MergeAll
from mangoPURE.data.providers import UrbanRandom
from datasets import load_dataset
from mangoPURE.models.collators import OneNoiseCollator
from transformers import WhisperFeatureExtractor, WhisperConfig
from torch.utils.data import DataLoader
from mangoPURE.models.modules import WhisperTimedModel
from mangoPURE.models.modules import WhisperEmbedder, LinearSoloHead
from mangoPURE.models.metrics import CrossEntropyLoss
from mango.training.MangoTrainer import MangoTrainer
from mango.training.MangoTrainer import TrainerConfig
import accelerate
import torch

In [4]:
# Provider wrapping the 'train' split of the danavery/urbansound8K dataset;
# it is handed to AddRandomFilledNoise inside get_data().
urban_provider = UrbanRandom(load_dataset('danavery/urbansound8K', split='train'))

In [5]:
from dataclasses import dataclass
import torch
from sklearn.metrics import f1_score
from mango.utils.multilabel import render_confusion_matrix_sololabel

@dataclass
class ClassificationAccuracy:
    """Callable that turns a trainer output into single-label classification metrics.

    Calling an instance returns a dict with the keys ``accuracy``, ``f1_macro``,
    ``f1_micro`` and ``conf_matrix``.
    """

    # NOTE(review): this field is not read anywhere in __call__ — confirm whether
    # it is still needed.
    threshold: float = 0.5

    def __call__(self, train_output):
        """Compute accuracy, macro/micro F1 and a rendered confusion matrix.

        Args:
            train_output: object exposing a ``model_outputs`` mapping that contains
                ``"head_output"`` (arg-maxed over its last dimension to obtain the
                predicted class) and ``"labels"`` (cast to int).

        Returns:
            dict: ``accuracy``, ``f1_macro`` and ``f1_micro`` as Python floats, plus
            ``conf_matrix`` as returned by ``render_confusion_matrix_sololabel``.
        """
        outputs = train_output.model_outputs
        predicted = outputs["head_output"].argmax(dim=-1).cpu()
        target = outputs["labels"].int().cpu()

        # Fraction of element-wise matches between predictions and labels.
        hits = predicted == target
        accuracy = hits.int().sum() / hits.numel()

        macro_f1 = f1_score(y_true=target, y_pred=predicted, average='macro')
        micro_f1 = f1_score(y_true=target, y_pred=predicted, average='micro')
        confusion = render_confusion_matrix_sololabel(target, predicted)

        return {
            "accuracy": accuracy.item(),
            "f1_macro": float(macro_f1),
            "f1_micro": float(micro_f1),
            'conf_matrix': confusion,
        }

In [6]:
def get_data(param):
    """Build the train and validation datasets for one sweep configuration.

    Both splits use the same three-step transform pipeline
    (CreateRandomBlankAudio -> AddRandomFilledNoise(urban_provider) -> MergeAll).
    When ``param['use_augment']`` is truthy, an AugmentationTransform is appended
    to the *training* pipeline only.

    NOTE(review): the recorded run fails with
    "'Tensor' object has no attribute 'audio'" in the collator when
    ``use_augment`` is True; appending the augmentation after MergeAll appears to
    make the mixer yield a bare tensor — confirm and fix the ordering or the
    transform's return type.

    Args:
        param: mapping that must contain the key ``'use_augment'``.

    Returns:
        tuple: ``(train_dataset, val_dataset)``, both TorchDatasetWrapper instances.
    """
    def base_pipeline():
        # A fresh list of fresh transform objects for each split.
        return [CreateRandomBlankAudio(), AddRandomFilledNoise(urban_provider), MergeAll()]

    train_pipeline = base_pipeline()
    val_pipeline = base_pipeline()
    if param['use_augment']:
        train_pipeline.append(AugmentationTransform())

    train_mixer = DatasetMixer(train_pipeline)
    val_mixer = DatasetMixer(val_pipeline)

    # Second positional argument of TorchDatasetWrapper; presumably the number of
    # items the wrapped dataset reports — TODO confirm.
    n_items = 15
    return TorchDatasetWrapper(train_mixer, n_items), TorchDatasetWrapper(val_mixer, n_items)


# Batch size to use for each Whisper checkpoint name; looked up in objective().
checkpoint_to_batch = {
    'whisper-tiny': 64,
    'whisper-base': 48,
    'whisper-small': 24
}

# Checkpoints the sweep may pick from. Currently narrowed to whisper-tiny; the
# full candidate list is kept in the trailing comment.
available_models = ['whisper-tiny']#['whisper-tiny', 'whisper-base', 'whisper-small']

# A single fp16 Accelerator, created once and passed to the MangoTrainer of every trial.
accelerator = accelerate.Accelerator(mixed_precision='fp16')

def objective(trial):
    """Optuna objective: train one sampled configuration and return its accuracy.

    Samples the checkpoint, augmentation flag, learning rate, CyclicLR gamma
    (stored under the key ``'momentum'``) and gradient-clipping flag from
    ``trial``, builds the datasets, loaders, model, optimizer and scheduler,
    runs ``trainer.train(1, None)`` followed by ``trainer.eval_iteration(1)``,
    and scores the evaluation output with ClassificationAccuracy.

    Args:
        trial: the ``optuna`` trial used to sample hyperparameters.

    Returns:
        float: the ``'accuracy'`` entry computed on the evaluation output.
    """
    param = {
        "checkpoint": trial.suggest_categorical("checkpoint", available_models),
        "use_augment": trial.suggest_categorical('use_augment', [False, True]),
        'lr': trial.suggest_float('lr', 4e-6, 1e-3, log=True),
        # Despite its name, this value is used only as CyclicLR's `gamma` below.
        'momentum': trial.suggest_float('momentum', 0, 1, step=0.1),
        'fp16': True,  # Unclear what to do here... trial.suggest_categorical('fp16', [False, True]),
        'grad_clip': trial.suggest_categorical('grad_clip', [False, True])
    }

    train, val = get_data(param)
    checkpoint = f'openai/{param["checkpoint"]}'

    whisper_config = WhisperConfig.from_pretrained(checkpoint)
    extractor = WhisperFeatureExtractor.from_pretrained(checkpoint)
    batch_size = checkpoint_to_batch[param["checkpoint"]]
    if not param['fp16']:
        # Halve the batch when not using fp16. Floor division keeps the value an
        # int (DataLoader rejects a float batch size); never drop below 1.
        batch_size = max(1, batch_size // 2)

    train_loader = DataLoader(train, batch_size, collate_fn=OneNoiseCollator(extractor), num_workers=2,
                              prefetch_factor=4)
    val_loader = DataLoader(val, batch_size, collate_fn=OneNoiseCollator(extractor), num_workers=2, prefetch_factor=4)

    # Head maps the encoder width (d_model) to 11 outputs — presumably the number
    # of target classes; TODO confirm.
    model = WhisperTimedModel(WhisperEmbedder(checkpoint), LinearSoloHead(whisper_config.d_model, 11),
                              CrossEntropyLoss())

    config = TrainerConfig('whisper-solo-clf',
                           logs_frequency_batches=1000000,
                           save_strategy='never',
                           early_stopping_patience=20,
                           push_to_hub=False,
                           mixed_precision='fp16',
                           grad_clip=param['grad_clip'])

    optim = torch.optim.AdamW(model.parameters(), lr=param['lr'], weight_decay=1e-5)
    # Cyclic schedule between lr and 10*lr; the up/down phases last 1.75 and 1.25
    # times the number of training batches, and the 'exp_range' amplitude decays
    # with gamma = param['momentum'].
    scheduler = torch.optim.lr_scheduler.CyclicLR(optim, base_lr=param['lr'], max_lr=param['lr'] * 10, mode='exp_range',
                                                  gamma=param['momentum'], cycle_momentum=False,
                                                  step_size_up=int(len(train_loader) * 1.75),
                                                  step_size_down=int(len(train_loader) * 1.25))

    trainer = MangoTrainer(
        model,
        train_loader,
        val_loader,
        config,
        accelerator=accelerator,
        optimizer=optim,
        scheduler=scheduler,
        trackers=[]
    )

    trainer.train(1, None)

    eval_output = trainer.eval_iteration(1)

    return ClassificationAccuracy()(eval_output)['accuracy']

In [7]:
import neptune
import os 

# Start a Neptune run in the mango/mango-sweep project; the API token is read
# from the NEPTUNE_TOKEN environment variable rather than hard-coded.
run = neptune.init_run(project='mango/mango-sweep', api_token=os.getenv('NEPTUNE_TOKEN'))



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/mango/mango-sweep/e/MNGSWP-3


In [8]:
import neptune.integrations.optuna as npt_utils

# Optuna callback bound to the Neptune run; it is passed to study.optimize().
neptune_callback = npt_utils.NeptuneCallback(run)

In [9]:
# Maximize the accuracy returned by objective() over up to 100 trials.
# NOTE(review): no sampler seed is set, so the sweep is presumably not
# reproducible; the recorded run also stopped at trial 1 when the objective
# raised — confirm whether failed trials should be caught instead.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100, callbacks=[neptune_callback])

[I 2024-04-08 23:24:15,916] A new study created in memory with name: no-name-fa37da4f-8e83-4ab5-a6bc-98fc2d13a060


train:   0%|          | 0/1 [00:00<?, ?it/s]

eval:   0%|          | 0/1 [00:00<?, ?it/s]

[I 2024-04-08 23:25:50,641] Trial 0 finished with value: 0.0 and parameters: {'checkpoint': 'whisper-tiny', 'use_augment': False, 'lr': 7.605802989739623e-05, 'momentum': 0.9288056848897888, 'grad_clip': False}. Best is trial 0 with value: 0.0.
[W 2024-04-08 23:25:50,980] Param grad_clip unique value length is less than 2.
[W 2024-04-08 23:25:50,981] Param lr unique value length is less than 2.
[W 2024-04-08 23:25:50,982] Param momentum unique value length is less than 2.
[W 2024-04-08 23:25:50,983] Param use_augment unique value length is less than 2.
[W 2024-04-08 23:25:50,983] Param checkpoint unique value length is less than 2.
[W 2024-04-08 23:25:50,984] Param lr unique value length is less than 2.
[W 2024-04-08 23:25:50,985] Param momentum unique value length is less than 2.
[W 2024-04-08 23:25:50,985] Param use_augment unique value length is less than 2.
[W 2024-04-08 23:25:50,986] Param checkpoint unique value length is less than 2.
[W 2024-04-08 23:25:50,987] Param grad_clip u

train:   0%|          | 0/1 [00:00<?, ?it/s]

eval:   0%|          | 0/1 [00:00<?, ?it/s]

[W 2024-04-08 23:26:10,885] Trial 1 failed with parameters: {'checkpoint': 'whisper-tiny', 'use_augment': True, 'lr': 4.754390426567961e-05, 'momentum': 0.13210065296932683, 'grad_clip': False} because of the following error: AttributeError('Caught AttributeError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n  File "C:\\Users\\bsvja\\anaconda3\\envs\\pythonProject\\lib\\site-packages\\torch\\utils\\data\\_utils\\worker.py", line 308, in _worker_loop\n    data = fetcher.fetch(index)\n  File "C:\\Users\\bsvja\\anaconda3\\envs\\pythonProject\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py", line 54, in fetch\n    return self.collate_fn(data)\n  File "C:\\Users\\bsvja\\PycharmProjects\\MangoDemo\\mangoPURE\\models\\collators.py", line 71, in __call__\n    batch = Whisper.extract_features(self.feature_extractor, [x.audio for x in batch_list])\n  File "C:\\Users\\bsvja\\PycharmProjects\\MangoDemo\\mangoPURE\\models\\collators.py", line 71, in <listco

AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "C:\Users\bsvja\anaconda3\envs\pythonProject\lib\site-packages\torch\utils\data\_utils\worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "C:\Users\bsvja\anaconda3\envs\pythonProject\lib\site-packages\torch\utils\data\_utils\fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "C:\Users\bsvja\PycharmProjects\MangoDemo\mangoPURE\models\collators.py", line 71, in __call__
    batch = Whisper.extract_features(self.feature_extractor, [x.audio for x in batch_list])
  File "C:\Users\bsvja\PycharmProjects\MangoDemo\mangoPURE\models\collators.py", line 71, in <listcomp>
    batch = Whisper.extract_features(self.feature_extractor, [x.audio for x in batch_list])
AttributeError: 'Tensor' object has no attribute 'audio'


In [None]:
# Inspection only: display the checkpoint -> batch-size mapping.
checkpoint_to_batch

In [11]:
# Debugging: rebuild the datasets with augmentation enabled, the setting used by
# the trial that failed in the recorded run (a = train dataset, b = validation dataset).
a, b = get_data({'use_augment': True})

In [15]:
# Debugging: draw one sample from the augmented train mixer; the recorded output
# is a bare tensor.
a.mixer.generate()

tensor([-0.0047, -0.0114, -0.0097,  ..., -0.0171, -0.0185, -0.0091])