In [1]:
import dataloaders.audio_dataset as dataset
import models.inversion_v1 as inversion_model
from abstract_model import AbstractModel

import torch
import torch.nn as nn
from torch import optim

from argparse import Namespace

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything

from ipywidgets import IntProgress

In [2]:
# Bind the dataset and model classes to short names so that later cells do
# not have to spell out the module prefix each time.
AudioDataset, InversionV1 = dataset.AudioDataset, inversion_model.InversionV1

# Directory of each LibriSpeech split, keyed by split name.
# NOTE(review): the "/scratch/prs392" prefix is spelled out literally here
# rather than derived from a shared constant -- keep it in sync by hand if the
# scratch location changes.
data_paths = {
    'train': '/scratch/prs392/incubator/data/LibriSpeech/train-clean-360',
    'val': '/scratch/prs392/incubator/data/LibriSpeech/dev-clean',
    'test': '/scratch/prs392/incubator/data/LibriSpeech/test-clean',
}

In [3]:
# Base directory from which the data and checkpoint paths are built.
# Replace the value with your own location, e.g.
# SCRATCH = "your/scratch/location"
SCRATCH = "/scratch/prs392"

In [4]:
# Identifiers for this run: `algo` names the checkpoint sub-directory and
# `experiment_name` is the run name handed to the logger.
algo, experiment_name = "inversion_v1", "train_with_specific_hparams"

# Locations derived from the SCRATCH base directory.
data_path = f"{SCRATCH}/incubator/data/LibriSpeech/"
checkpoint_path = f"{SCRATCH}/incubator/checkpoints/openl3_librispeech/{algo}/"

In [5]:
# Smoke test: push a single training sample through a freshly constructed
# (untrained) InversionV1 and print the shapes involved so the expected and
# predicted spectrogram shapes can be compared by eye.
audio_dataset = AudioDataset(root_dir=data_paths['train'], num_audios=10, return_amp=True)

if len(audio_dataset) > 0:
    # Only the first sample is inspected. Each item unpacks into three values;
    # the third one (`j`) is not used by this check.
    emb, spec, j = audio_dataset[0]
    print("Embeddings shape: " + str(emb.shape))
    print("Expected Spectrogram shape: " + str(spec.shape))

    m = InversionV1()
    # This is a shape check only, so no gradients are needed: skip the
    # autograd bookkeeping for the forward pass.
    with torch.no_grad():
        pred = m(emb)
    print("Predicted Spectrogram shape: " + str(pred.shape))
else:
    # Make an empty dataset visible instead of silently printing nothing.
    print("Dataset is empty; nothing to check.")

Embeddings shape: torch.Size([6144])
Expected Spectrogram shape: torch.Size([1, 128, 199])
Predicted Spectrogram shape: torch.Size([1, 1, 128, 199])


In [10]:
# Hyperparameters for the run, collected in one literal and then exposed as
# attributes through an argparse Namespace (passed to AbstractModel as
# `hparams` in the training cell).
args = {
    # Optimisation settings; the first four are marked compulsory.
    'batch_size': 8,             # Compulsory
    'lr': 0.0002,                # Compulsory
    'scheduler_epoch': 3,        # Compulsory
    'scheduler_step_size': 0.1,  # Compulsory
    'lr_type': 'adam',
    # Dataset / loader settings.
    'train_num_audios': 10,
    'val_num_audios': 10,
    'test_num_audios': 10,
    'return_amp': True,
    'num_workers': 0,
}

hparams = Namespace(**args)

In [11]:
# Display the learning rate stored on the hparams Namespace as a quick check.
hparams.lr

0.0002

In [None]:
# Train InversionV1 wrapped in AbstractModel, logging to TensorBoard and
# checkpointing on validation loss, then run the test loop.
seed_everything(123)

model = AbstractModel(
    hparams=hparams,
    data_paths=data_paths,
    dataset_model=AudioDataset,
    model=InversionV1(),
    criterion=nn.MSELoss(),
)

logger = TensorBoardLogger(checkpoint_path, name=experiment_name)

checkpoint_callback = ModelCheckpoint(
    filepath=None,
    save_top_k=1,  # integer count: keep the single best checkpoint by val_loss
    save_last=True,
    verbose=False,
    monitor='val_loss',
    mode='min',
    prefix='',
)

# Options shared by the CPU and GPU runs live in one place so the two
# configurations cannot drift apart.
trainer_kwargs = dict(
    logger=logger,
    default_root_dir=checkpoint_path,
    checkpoint_callback=checkpoint_callback,
    max_epochs=100,
    check_val_every_n_epoch=1,
    fast_dev_run=False,
)

if torch.cuda.device_count() == 0:
    print('cpu')
else:
    # GPU-only options: use every available GPU with the 'dp' backend.
    trainer_kwargs.update(gpus=-1, distributed_backend='dp')

trainer = Trainer(**trainer_kwargs)

trainer.fit(model)
trainer.test(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


cpu


Set SLURM handle signals.

  | Name      | Type        | Params
------------------------------------------
0 | model     | InversionV1 | 7 M   
1 | criterion | MSELoss     | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…






Set SLURM handle signals.


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…