In [None]:
import os
import json
import time

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn

import display
from dataset.libri import load_data
from models import FCAE, CDAE, UNet
from train import set_device, train, evaluate

# Set compute device. `device` is read by every later cell (model placement,
# the `pin_memory` data option, train() and evaluate()).
# NOTE(review): whether set_device returns a str or a torch.device is not
# visible here — confirm before comparing `device` against string literals.
device = set_device(verbose=True)

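## Fully-Connected Autoencoder

Each of the three configuration cells in this notebook (FCAE, CDAE, U-Net) assigns `params`. The training cell at the end uses whichever one was run last, so under *Run All* it trains the U-Net.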

In [None]:
# Configuration for the fully-connected autoencoder (FCAE).
# Running this cell (re)binds `params`; the training cell at the end of the
# notebook uses whichever configuration cell was executed most recently.
params = {
    'network': FCAE,
    'data': {
        'N': 10,
        'test_size': .10,
        'data_root': 'data/noised_synth_babble',
        'libri_root': 'data/LibriSpeech/dev-clean',
        'batch_size': 8,
        # torch.device(...) accepts both strings ('cuda', 'cuda:0') and
        # torch.device objects, so pinning is enabled on any CUDA device.
        # A plain `device == 'cuda'` is False for torch.device objects and
        # for indexed devices such as 'cuda:0'.
        'pin_memory': torch.device(device).type == 'cuda',
        'conv': False,  # presumably selects a non-convolutional input layout in load_data — confirm
        'seed': 1,
        'srate': 16000  # presumably the audio sample rate in Hz — confirm in load_data
    },
    'model': {
        'in_shape': (256, 256),
        'n_layers': 4,
        'z_dim': 8
    },
    'train': {
        'epochs': 2,
        'learning_rate': 0.001,
        'criterion': nn.MSELoss()
    }
}

## Convolutional Autoencoder

In [None]:
# Configuration for the convolutional denoising autoencoder (CDAE).
# Running this cell (re)binds `params`; the training cell at the end of the
# notebook uses whichever configuration cell was executed most recently.
params = {
    'network': CDAE,
    'data': {
        'N': 10,
        'test_size': .10,
        'data_root': 'data/noised_synth_babble',
        'libri_root': 'data/LibriSpeech/dev-clean',
        'batch_size': 8,
        # torch.device(...) accepts both strings ('cuda', 'cuda:0') and
        # torch.device objects, so pinning is enabled on any CUDA device.
        # A plain `device == 'cuda'` is False for torch.device objects and
        # for indexed devices such as 'cuda:0'.
        'pin_memory': torch.device(device).type == 'cuda',
        'conv': True,  # presumably selects a channel-first layout in load_data — confirm
        'seed': 1,
        'srate': 16000  # presumably the audio sample rate in Hz — confirm in load_data
    },
    'model': {
        'n_layers': 4,
        'z_dim': 8,
        'in_channels': 1,
        'batch_norm': True
    },
    'train': {
        'epochs': 2,
        'learning_rate': 0.001,
        # nn.BCELoss requires inputs that are already probabilities in [0, 1]
        # and targets in [0, 1].
        # NOTE(review): confirm CDAE ends in a sigmoid and the data is scaled to [0, 1].
        'criterion': nn.BCELoss()
    }
}

## U-Net

In [None]:
# Configuration for the U-Net.
# Running this cell (re)binds `params`; the training cell at the end of the
# notebook uses whichever configuration cell was executed most recently
# (under Restart & Run All, that is this one).
params = {
    'network': UNet,
    'data': {
        'N': 10,
        'test_size': .10,
        'data_root': 'data/noised_synth_babble',
        'libri_root': 'data/LibriSpeech/dev-clean',
        'batch_size': 8,
        # torch.device(...) accepts both strings ('cuda', 'cuda:0') and
        # torch.device objects, so pinning is enabled on any CUDA device.
        # A plain `device == 'cuda'` is False for torch.device objects and
        # for indexed devices such as 'cuda:0'.
        'pin_memory': torch.device(device).type == 'cuda',
        'conv': True,  # presumably selects a channel-first layout in load_data — confirm
        'seed': 1,
        'srate': 16000  # presumably the audio sample rate in Hz — confirm in load_data
    },
    'model': {
        'in_shape': (256, 256),
        'in_channels': 1,
        'n_classes': 1,
        'encoder_channels': (4, 8, 16),
        'decoder_channels': (16, 8, 4),
        'retain_dim': True
    },
    'train': {
        'epochs': 2,
        'learning_rate': 0.001,
        # BCEWithLogitsLoss applies the sigmoid internally, so the network is
        # expected to emit raw logits.
        # NOTE(review): check whether evaluate() applies a sigmoid before plotting outputs.
        'criterion': nn.BCEWithLogitsLoss()
    }
}

In [None]:
# Build, train and evaluate the model described by `params`
# (i.e. whichever configuration cell above was run last).

# Seed torch's global RNG so weight initialisation, and any shuffling that
# draws on that RNG, is repeatable across Restart & Run All. load_data
# receives the same value through params['data'].
torch.manual_seed(params['data']['seed'])

model = params['network'](**params['model']).to(device)
print(model)
print('---')

print('\nLoading data...\n')
data_train, train_dl, data_val, val_dl, data_test = load_data(**params['data'])
# NOTE: `display` here is the project module imported at the top of the
# notebook; it shadows IPython's built-in display().
display.show_split_sizes((data_train, data_val, data_test))

print('\nTraining model...\n')
model, hist = train(device, model, train_dl, val_dl, **params['train'])

# Plot losses. plt.show() renders the figure immediately, before evaluation
# starts. Figure.show() is a no-op under the inline (non-GUI) backend and
# emits a UserWarning there.
fig, ax = plt.subplots(figsize=(10, 5))
ax = display.plot_losses(ax, hist, repr(params['train']['criterion']))
plt.show()

# Evaluate the trained model on the test split returned by load_data.
fig, axes = evaluate(device, model, data_test)
plt.show()