In [None]:
import os
import warnings
# Blanket filter is installed before the third-party imports so their
# import-time warnings are hidden too.
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchaudio
import torchaudio.transforms as T

import pytorch_lightning as pl
# NOTE(review): `pytorch_lightning.metrics` is deprecated in newer Lightning
# releases (moved to `torchmetrics`); this import pins the notebook to older versions.
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers

from project.data import UrbanSoundDataset
from project.model import LitCRNN

In [None]:
device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
print(f"Training on device {device}")

# Seed Python / NumPy / torch RNGs so loader shuffling and the model's
# weight initialisation (next cell) are reproducible across runs.
SEED = 42
pl.seed_everything(SEED)

tb_logger = pl_loggers.TensorBoardLogger('../logs/')

csv_path = '../datasets/UrbanSound8K/metadata/UrbanSound8K.csv'
file_path = '../datasets/UrbanSound8K/audio/'

# Transform to a log-amplitude mel-spectrogram, using the sample rate,
# n_fft, hop_length and n_mels from the paper.
# NOTE(review): MelSpectrogram's sample_rate must match the rate of the
# waveforms UrbanSoundDataset returns; this assumes the dataset resamples
# to 12 kHz — confirm in project.data.
transform = nn.Sequential(
    T.MelSpectrogram(sample_rate=12000, n_fft=512, hop_length=256, n_mels=96),
    T.AmplitudeToDB()
)

# Divide the ten UrbanSound8K folds into training and validation folds.
folds = list(range(1, 11))
val_folds = [10]  # CHANGE HERE FOR A DIFFERENT VALIDATION FOLD!
train_folds = [fold for fold in folds if fold not in val_folds]
assert set(val_folds) <= set(folds), f"val_folds must be drawn from {folds}"
assert train_folds, "at least one fold must remain for training"

# Create the training and validation sets using the chosen transform.
train_set = UrbanSoundDataset(csv_path, file_path, train_folds, transform=transform)
val_set = UrbanSoundDataset(csv_path, file_path, val_folds, transform=transform)
print("Train set size: " + str(len(train_set)))
print("val set size: " + str(len(val_set)))

BATCH_SIZE = 12
# Never spawn more worker processes than there are CPUs. PyTorch would only
# warn about oversubscription, and that warning is silenced by the import
# cell's warnings.filterwarnings('ignore').
NUM_WORKERS = min(20, os.cpu_count() or 1)
# Pinned host memory only speeds up host->GPU copies; skip it on CPU.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

In [None]:
# Stop training when the monitored 'val_acc' has not improved for 5
# consecutive validation checks.
# NOTE(review): assumes LitCRNN logs a metric named 'val_acc' — confirm in project.model.
early_stop_callback = EarlyStopping(
   monitor='val_acc',
   min_delta=0.00,
   patience=5,
   verbose=True,
   mode='max'
)

# Keep the three best checkpoints by validation accuracy. The fold tag is
# derived from `val_folds`, so changing the validation fold no longer writes
# weights mislabelled as fold 10. Doubled braces survive the f-string as the
# literal `{epoch:02d}` / `{val_acc:.2f}` placeholders that Lightning fills in.
fold_tag = '-'.join(str(fold) for fold in val_folds)
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    dirpath='../weights',
    filename=f'fold-{fold_tag}-{{epoch:02d}}-{{val_acc:.2f}}',
    save_top_k=3,
    mode='max'
)

model = LitCRNN()
# Request a GPU only when one exists: a hard-coded gpus=1 makes the Trainer
# fail on a CPU-only machine even though the device check above falls back to CPU.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     max_epochs=20, progress_bar_refresh_rate=20,
                     callbacks=[checkpoint_callback, early_stop_callback],
                     logger=tb_logger)

In [None]:
# Train on train_loader and validate on val_loader; the checkpoint and
# early-stopping callbacks attached to `trainer` fire during this call.
trainer.fit(model, train_loader, val_loader)

# Display where the best checkpoint (highest val_acc) was written.
checkpoint_callback.best_model_path