In [7]:
import dataloaders.audio_dataset as dataset
import models.inversion_v1 as inversion_model
from abstract_model import AbstractModel

import os

import torch
import torch.nn as nn
from torch import optim

from argparse import Namespace
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting

from ax.service.ax_client import AxClient

from pathlib import Path

from ipywidgets import IntProgress

In [8]:
from tensorboard.backend.event_processing import event_accumulator
import numpy as np

def _load_run(path):
    """Collect every scalar series that TensorBoard logged under ``path``.

    Args:
        path: location handed to ``event_accumulator.EventAccumulator``.

    Returns:
        dict mapping each scalar tag to a ``(steps, values)`` pair of numpy
        arrays; tags are inserted in sorted order.
    """
    accumulator = event_accumulator.EventAccumulator(path)
    accumulator.Reload()

    series_by_tag = {}
    for tag in sorted(accumulator.Tags()["scalars"]):
        # Materialise the events once so both arrays are built from the same records.
        events = list(accumulator.Scalars(tag))
        steps = np.asarray([event.step for event in events])
        values = np.asarray([event.value for event in events])
        series_by_tag[tag] = (steps, values)

    return series_by_tag

In [9]:
# Root of the scratch file system that holds the dataset and the checkpoints.
# The SCRATCH environment variable wins when it is set (and non-empty), so the
# notebook runs unchanged for other users; otherwise the original location is used.
SCRATCH = os.environ.get("SCRATCH") or "/scratch/prs392"

In [10]:
# Names that select which model family / experiment the directories below belong to.
algo = "inversion_v1"
experiment_name = "train_and_tune"

# LibriSpeech root; train_evaluate appends the split folder names to it.
data_path = "{}/incubator/data/LibriSpeech/".format(SCRATCH)
# Root under which the TensorBoard logger writes <experiment_name>/<version>/ folders.
checkpoint_path = "{}/incubator/checkpoints/openl3_librispeech/{}/".format(SCRATCH, algo)

In [11]:
# Directory of the current experiment; created on first use so the listing below never fails.
d = os.path.join(checkpoint_path, experiment_name)
os.makedirs(d, exist_ok=True)

# Every sub-directory is one logged run ("version"); keep their names in sorted order.
with os.scandir(d) as entries:
    versions = sorted(entry.name for entry in entries if entry.is_dir())

In [12]:
from pytorch_lightning.core.saving import load_hparams_from_yaml

# Keys that train_evaluate adds to the hyper-parameters before training. They are not
# part of the Ax search space, so they are removed from every recorded run; otherwise a
# recorded run could never equal a fresh parameterization nor be attached as a trial.
_NON_SEARCH_KEYS = (
    "train_num_audios",
    "val_num_audios",
    "test_num_audios",
    "return_amp",
    "num_workers",
)

# Parallel lists: hyper-parameters of each finished run and its best validation loss.
list_existing_hparams = []
list_of_val_loss = []

for version in versions:
    version_dir = os.path.join(checkpoint_path, experiment_name, version)
    hparams_new = load_hparams_from_yaml(os.path.join(version_dir, 'hparams.yaml'))
    print(hparams_new)

    # Parse the event files once per version (they were previously read twice).
    scalars = _load_run(version_dir)
    has_val_loss = 'val_loss' in scalars and len(scalars['val_loss'][1]) > 0

    # Interrupted runs leave a version folder without hyper-parameters and/or without
    # any logged validation loss; they cannot be replayed, so skip them instead of crashing.
    if not hparams_new or not has_val_loss:
        print(f"Skipping {version}: missing hparams or no 'val_loss' scalars")
        continue

    best_val_loss = min(scalars['val_loss'][1])
    print(best_val_loss)

    for key in _NON_SEARCH_KEYS:
        hparams_new.pop(key, None)

    # Nothing left to compare against once the bookkeeping keys are gone.
    if not hparams_new:
        print(f"Skipping {version}: no search-space hyper-parameters recorded")
        continue

    list_existing_hparams.append(hparams_new)
    list_of_val_loss.append(best_val_loss)

list_existing_hparams, list_of_val_loss

([], [])

In [13]:
def train_evaluate(parameterization):
    """Train one hyper-parameter setting and report its best validation loss.

    If a finished run recorded in ``list_existing_hparams`` agrees with
    ``parameterization`` on every key that run recorded, the matching entry of
    ``list_of_val_loss`` is returned and nothing is trained.

    Reads the notebook globals ``data_path``, ``checkpoint_path``,
    ``experiment_name``, ``list_existing_hparams`` and ``list_of_val_loss``.

    Args:
        parameterization: dict of hyper-parameters for this trial. It is only
            read; the bookkeeping keys needed for training are added to a
            private copy so the caller's dict is left untouched.

    Returns:
        ``{'val_loss': (loss, 0.0)}`` where ``loss`` is either the stored loss
        of the matching finished run or ``model.best_validation_loss`` after
        training. The second tuple element is always ``0.0``.
    """
    AudioDataset = dataset.AudioDataset
    InversionV1 = inversion_model.InversionV1

    data_paths = {
        'train': os.path.join(data_path, 'train-clean-360'),
        'val': os.path.join(data_path, 'dev-clean'),
        'test': os.path.join(data_path, 'test-clean'),
    }

    seed_everything(123)

    print(parameterization)

    finished_idx = _find_finished_run_index(parameterization)
    if finished_idx is not None:
        print("Val loss: " + str(list_of_val_loss[finished_idx]))
        return {'val_loss': (list_of_val_loss[finished_idx], 0.0)}

    # Work on a copy: the caller (and Ax) keep a reference to the original dict.
    run_config = dict(parameterization)
    run_config.update(
        train_num_audios=-1,
        val_num_audios=-1,
        test_num_audios=-1,
        return_amp=True,
        num_workers=20,
    )
    hparams = Namespace(**run_config)

    model = AbstractModel(
        hparams=hparams,
        data_paths=data_paths,
        dataset_model=AudioDataset,
        model=InversionV1(),
        criterion=nn.MSELoss(),
    )
    logger = TensorBoardLogger(checkpoint_path, name=experiment_name)

    checkpoint_callback = ModelCheckpoint(
        filepath=None,
        save_top_k=1,  # an integer count; the former ``True`` compared equal to 1
        save_last=True,
        verbose=False,
        monitor='val_loss',
        mode='min',
        prefix='',
    )

    # Arguments shared by the CPU and GPU configurations.
    trainer_kwargs = dict(
        logger=logger,
        default_root_dir=checkpoint_path,
        checkpoint_callback=checkpoint_callback,
        max_epochs=100,
        check_val_every_n_epoch=1,
        fast_dev_run=False,
    )
    if torch.cuda.device_count() > 0:
        # Use every visible GPU with the 'dp' backend.
        trainer_kwargs.update(gpus=-1, distributed_backend='dp')
    trainer = Trainer(**trainer_kwargs)

    trainer.fit(model)
    trainer.test(model)
    print("Val loss: " + str(model.best_validation_loss))
    return {'val_loss': (model.best_validation_loss, 0.0)}


def _find_finished_run_index(parameterization):
    """Return the index of the first recorded run matching ``parameterization``, else None.

    A recorded run matches when every key it stores is present in
    ``parameterization`` with an equal value. Empty records are ignored: they
    would otherwise match every parameterization.
    """
    for idx, existing_hparams in enumerate(list_existing_hparams):
        if not existing_hparams:
            continue
        if all(
            key in parameterization and value == parameterization[key]
            for key, value in existing_hparams.items()
        ):
            return idx
    return None

In [14]:
# Ax service client that proposes parameterizations and records their results.
ax_client = AxClient()
ax_client.create_experiment(
    name="choose_optimizer_scheduler",
    parameters=[
        # Discrete choices.
        dict(name="batch_size", type="choice", values=[8, 16, 32]),
        # Continuous range, searched on a log scale.
        dict(name="lr", type="range", bounds=[1e-6, 0.1], log_scale=True),
        dict(name="lr_type", type="choice", values=['adam', 'sgd']),
        dict(name="scheduler_epoch", type="choice", values=[3, 5, 7, 9]),
        dict(name="scheduler_step_size", type="range", bounds=[0.1, 1.0]),
    ],
    # train_evaluate reports its result under this metric name; lower is better.
    objective_name="val_loss",
    minimize=True,
)

[INFO 07-11 22:51:24] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.
[INFO 07-11 22:51:24] ax.modelbridge.dispatch_utils: Using Sobol generation strategy.


In WithDBSettings, db settings: None


In [15]:
# Overall trial budget; replayed and hand-picked trials are counted against it.
total_number_of_trials = 100


def _replay_trial(params):
    """Attach ``params`` to ``ax_client`` as a trial and complete it with ``train_evaluate``'s result."""
    parameters, trial_index = ax_client.attach_trial(params)
    ax_client.complete_trial(trial_index=trial_index, raw_data=train_evaluate(parameters))


# Hand-picked parameterizations to evaluate before the automatic search starts.
# Every value has to lie inside the search space declared in create_experiment:
# batch_size was 256 here, which is not one of the declared choices [8, 16, 32]
# and disagrees with the recorded output of this cell (batch_size 8).
# To try a larger batch, add it to the "batch_size" choices first.
custom_parameters = [
    {
        'batch_size': 8,
        'lr': 0.0004136762567284789,
        'lr_type': 'sgd',
        'scheduler_epoch': 7,
        'scheduler_step_size': 0.9465236907824874,
    }
]

# First feed Ax the runs that already finished, then the hand-picked ones.
for params in list(list_existing_hparams) + custom_parameters:
    _replay_trial(params)
    total_number_of_trials -= 1

# Spend the remaining budget on parameterizations proposed by Ax.
for _ in range(total_number_of_trials):
    parameters, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=trial_index, raw_data=train_evaluate(parameters))

[INFO 07-11 22:51:25] ax.service.ax_client: Attached custom parameterization {'batch_size': 8, 'lr': 0.0, 'lr_type': 'sgd', 'scheduler_epoch': 7, 'scheduler_step_size': 0.95} as trial 0.


{'batch_size': 8, 'lr': 0.0004136762567284789, 'lr_type': 'sgd', 'scheduler_epoch': 7, 'scheduler_step_size': 0.9465236907824874}


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0,1,2]
Set SLURM handle signals.

  | Name      | Type        | Params
------------------------------------------
0 | model     | InversionV1 | 7 M   
1 | criterion | MSELoss     | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x2b5171ab2d40>
Traceback (most recent call last):
  File "/ext3/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 962, in __del__
    self._shutdown_workers()
  File "/ext3/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 942, in _shutdown_workers
    w.join()
  File "/ext3/miniconda3/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/ext3/miniconda3/lib/python3.7/multiprocessing/popen_fork.py", line 48, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/ext3/miniconda3/lib/python3.7/multiprocessing/popen_fork.py", line 28, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt: 
Exception ignored in: 




<function WeakValueDictionary.__init__.<locals>.remove at 0x2b517170fdd0>
Traceback (most recent call last):
  File "/ext3/miniconda3/lib/python3.7/weakref.py", line 109, in remove
    def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
KeyboardInterrupt

Detected KeyboardInterrupt, attempting graceful shutdown...



KeyboardInterrupt: 