In [19]:
import dataloaders.audio_dataset as dataset
import models.inversion_v3_stacked as inversion_model
from abstract_model import AbstractModel

import torch
import torch.nn as nn
from torch import optim

import os

from argparse import Namespace
from pathlib import Path

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything

from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting

from tensorboard.backend.event_processing import event_accumulator

from ax.service.ax_client import AxClient

import numpy as np

from ipywidgets import IntProgress

In [20]:
def _load_run(path):
    event_acc = event_accumulator.EventAccumulator(path)
    event_acc.Reload()
    data = {}

    for tag in sorted(event_acc.Tags()["scalars"]):
        x, y = [], []

        for scalar_event in event_acc.Scalars(tag):
            x.append(scalar_event.step)
            y.append(scalar_event.value)

        data[tag] = (np.asarray(x), np.asarray(y))
    return data

In [21]:
AudioDataset = dataset.AudioDataset
InversionV3 = inversion_model.InversionV3

data_paths = {}
data_paths['train'] = '/scratch/prs392/incubator/data/LibriSpeech/speaker_identification/train'
data_paths['val'] = '/scratch/prs392/incubator/data/LibriSpeech/speaker_identification/val'
data_paths['test'] = '/scratch/prs392/incubator/data/LibriSpeech/speaker_identification/test'

In [38]:
# SCRATCH = "your/scratch/location"
SCRATCH = "/scratch/prs392"
algo = "inversion_v3_sketch_identification"
data_path = f"{SCRATCH}/incubator/data/LibriSpeech/speaker_identification"
checkpoint_path = f"{SCRATCH}/incubator/checkpoints/openl3_librispeech/{algo}/"
experiment_name = "tuning"

In [39]:
checkpoint_path

'/scratch/prs392/incubator/checkpoints/openl3_librispeech/inversion_v3_sketch_identification/'

In [40]:
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

In [41]:
audio_dataset = AudioDataset(root_dir=data_paths['train'], num_audios = 10, return_amp = True)

for i in range(len(audio_dataset)):
    emb, spec, audio_prep, file_name, j = audio_dataset[i]
    print("Embeddings shape: " + str(emb.shape))
    print("Expected Spectrogram shape: " + str(spec.shape))
    m = InversionV3()
    pred = m(emb)
    print("Predicted Spectrogram shape: " + str(pred.shape))
    if i == 0:
        break

Embeddings shape: torch.Size([6144])
Expected Spectrogram shape: torch.Size([1, 128, 199])
Predicted Spectrogram shape: torch.Size([1, 1, 128, 199])


In [42]:
d = os.path.join(checkpoint_path, experiment_name)
Path(d).mkdir(parents=True, exist_ok=True)
versions = [o for o in os.listdir(d) if os.path.isdir(os.path.join(d,o))]
versions = sorted(versions)

In [48]:
from pytorch_lightning.core.saving import load_hparams_from_yaml

list_existing_hparams = []
list_of_val_loss = []

for version in versions:
    hparam_path = os.path.join(checkpoint_path, experiment_name, version, 'hparams.yaml')
    hparams_new = load_hparams_from_yaml(hparam_path)
    
    print(hparams_new)
    print(min(_load_run(os.path.join(checkpoint_path, experiment_name, version))['val_loss'][1]))
    hparams_new.pop("return_amp", None)    
    hparams_new.pop("num_workers", None)
    hparams_new.pop("num_frames", None)
    hparams_new.pop("emb_means", None)
    hparams_new.pop("emb_stds", None)
    hparams_new.pop("spec_means", None)
    hparams_new.pop("spec_stds", None)
    hparams_new.pop("test_num_audios", None)
    hparams_new.pop("train_num_audios", None)
    hparams_new.pop("val_num_audios", None)
    
    list_existing_hparams.append(hparams_new)
    list_of_val_loss.append(min(_load_run(os.path.join(checkpoint_path, experiment_name, version))['val_loss'][1]))

{'batch_size': 16, 'emb_means': None, 'emb_stds': None, 'lr': 0.0026649732007104645, 'lr_type': 'adam', 'num_frames': -1, 'num_workers': 5, 'return_amp': True, 'scheduler_epoch': 5, 'scheduler_step_size': 0.19578851433470845, 'spec_means': None, 'spec_stds': None, 'test_num_audios': -1, 'train_num_audios': -1, 'val_num_audios': -1}
0.007523973472416401


In [49]:
list_existing_hparams, list_of_val_loss

([{'batch_size': 16,
   'lr': 0.0026649732007104645,
   'lr_type': 'adam',
   'scheduler_epoch': 5,
   'scheduler_step_size': 0.19578851433470845}],
 [0.007523973472416401])

In [55]:
def train_evaluate(parameterization):
    
    AudioDataset = dataset.AudioDataset
    InversionV3 = inversion_model.InversionV3
    
    data_paths = {}
    data_paths['train'] = os.path.join(data_path, 'train')
    data_paths['val'] = os.path.join(data_path, 'val')
    data_paths['test'] = os.path.join(data_path, 'test')
    
    seed_everything(123)
    
    print(parameterization)
    
    for idx, existing_hparams in enumerate(list_existing_hparams):
        shared_items = {k: existing_hparams[k] for k in existing_hparams if k in parameterization and existing_hparams[k] == parameterization[k]}
        
        if len(existing_hparams) == len(shared_items):
            print("Val loss: " + str(list_of_val_loss[idx]))
            return {'val_loss' : (list_of_val_loss[idx], 0.0)}
        
    parameterization['return_amp'] = True
    parameterization['num_workers'] = 7
    
    hparams = Namespace(**parameterization)
    
    model = AbstractModel(
                hparams=hparams,
                data_paths = data_paths, 
                dataset_model = AudioDataset,
                model = InversionV3(), 
                criterion = nn.MSELoss()
            )

    logger = TensorBoardLogger(checkpoint_path, name=experiment_name)

    checkpoint_callback = ModelCheckpoint(
        filepath=None,
        save_top_k=True,
        save_last = True,
        verbose=False,
        monitor='val_loss',
        mode='min',
        prefix=''
    )

    if torch.cuda.device_count() == 0:
        print('cpu')
        trainer = Trainer(
            logger=logger,
            default_root_dir=checkpoint_path,
            checkpoint_callback = checkpoint_callback,
            row_log_interval=50,
            log_save_interval=500,
            val_check_interval=0.10,
            max_epochs=2,
            fast_dev_run=False
        )
    else:
        trainer = Trainer(
            logger=logger,
            default_root_dir=checkpoint_path,
            checkpoint_callback = checkpoint_callback,
            row_log_interval=50,
            log_save_interval=500,
            val_check_interval=0.10,
            gpus = -1,
            distributed_backend='dp',
            max_epochs=2,
            fast_dev_run=False
        )

    trainer.fit(model)
    trainer.test(model)
    print("Val loss: " + str(model.best_validation_loss))
    return {'val_loss' : (model.best_validation_loss, 0.0)}

In [56]:
ax_client = AxClient()
ax_client.create_experiment(
    name="choose_optimizer_scheduler",
    parameters=[
        {"name": "batch_size", "type": "choice", "values": [16, 32]},
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.1], "log_scale": True},
        {"name": "lr_type", "type": "choice", "values": ['adam', 'sgd']},
        {"name": "scheduler_epoch", "type": "choice", "values": [3, 5, 7, 9]},
        {"name": "scheduler_step_size", "type": "range", "bounds": [0.1, 1.0]},
    ],
    objective_name="val_loss",
    minimize=True,
)

[INFO 08-05 20:56:22] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.
[INFO 08-05 20:56:22] ax.modelbridge.dispatch_utils: Using Sobol generation strategy.


In WithDBSettings, db settings: None


In [None]:
total_number_of_trials = 50

for params in list_existing_hparams:
    parameters, trial_index = ax_client.attach_trial(params)
    ax_client.complete_trial(trial_index=trial_index, raw_data=train_evaluate(parameters))
    total_number_of_trials -= 1
    
for _ in range(total_number_of_trials - len(list_existing_hparams)):        
    parameters, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=trial_index, raw_data=train_evaluate(parameters))
    

[INFO 08-05 20:56:47] ax.service.ax_client: Attached custom parameterization {'batch_size': 16, 'lr': 0.0, 'lr_type': 'adam', 'scheduler_epoch': 5, 'scheduler_step_size': 0.2} as trial 4.
[INFO 08-05 20:56:47] ax.service.ax_client: Completed trial 4 with data: {'val_loss': (0.01, 0.0)}.
[INFO 08-05 20:56:47] ax.service.ax_client: Generated new trial 5 with parameters {'lr': 0.0, 'scheduler_step_size': 0.73, 'batch_size': 32, 'lr_type': 'sgd', 'scheduler_epoch': 7}.


{'batch_size': 16, 'lr': 0.0026649732007104645, 'lr_type': 'adam', 'scheduler_epoch': 5, 'scheduler_step_size': 0.19578851433470845}
Val loss: 0.007523973472416401
{'lr': 1.8363759499907522e-05, 'scheduler_step_size': 0.7303553459234535, 'batch_size': 32, 'lr_type': 'sgd', 'scheduler_epoch': 7}


GPU available: True, used: True
INFO:lightning:GPU available: True, used: True
TPU available: False, using: 0 TPU cores
INFO:lightning:TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0,1,2,3,4]
INFO:lightning:CUDA_VISIBLE_DEVICES: [0,1,2,3,4]
Set SLURM handle signals.
INFO:lightning:Set SLURM handle signals.

  | Name      | Type        | Params
------------------------------------------
0 | model     | InversionV3 | 16 M  
1 | criterion | MSELoss     | 0     
INFO:lightning:
  | Name      | Type        | Params
------------------------------------------
0 | model     | InversionV3 | 16 M  
1 | criterion | MSELoss     | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…