In [None]:
import os
import json
import time
import pickle

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn

import display
from dataset.libri import load_data
from models import FCAE, CDAE, UNet
import train

# Pick the compute device shared by every later cell (model, data, checkpoint
# loading); verbose=True presumably makes the helper print its selection.
device = train.set_device(
    verbose=True,
)

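## Training and evaluation

Load the run configuration, train the model, reload the weights saved during training, then plot the loss curves and evaluate on the test split.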

In [None]:
# --- Run configuration -------------------------------------------------------
# Overrides applied on top of the JSON config, kept as named constants so the
# values that define this run are visible in one place.
config_path = 'configs/fcae.json'
DATA_N = 1              # value for the `N` argument of load_data (see dataset.libri)
EPOCHS = 20
LEARNING_RATE = 0.001
SEED = 0                # re-applied on every run of this cell so results repeat

# Seed torch (all devices) and numpy's global RNG before the model is built and
# the data are loaded, so weight init and any shuffling/splitting that draws from
# these generators is reproducible when the cell is re-run. Bit-exact GPU results
# may additionally need PyTorch's deterministic-algorithm settings.
torch.manual_seed(SEED)
np.random.seed(SEED)

params = train.load_params_from_config(config_path, device)
params['data']['N'] = DATA_N
params['train']['epochs'] = EPOCHS
params['train']['learning_rate'] = LEARNING_RATE

# params['network'] is a callable that builds the model from the kwargs in
# params['model'].
model = params['network'](**params['model']).to(device)
print(model)
print('---')

print('\nLoading data...\n')
data_train, train_dl, data_val, val_dl, data_test = load_data(
    **params['data'])
# NOTE: `display` is the project module imported in the first cell; it shadows
# IPython's built-in display() function for the rest of this notebook.
display.show_split_sizes((data_train, data_val, data_test))

print('\nTraining model...\n')
model, hist = train.train(
    device, model, params['name'],
    train_dl, val_dl,
    **params['train'])

# Rebuild a fresh model and load the state dict train.train saved at
# hist['model_path'] (presumably the best-validation checkpoint -- confirm in
# train.train) instead of using the model object returned above.
# torch.load unpickles the file: only load checkpoints you produced or trust.
print('\nLoading trained model...\n')
model = params['network'](**params['model']).to(device)
model.load_state_dict(torch.load(
    hist['model_path'], map_location=device))
# Switch dropout / batch-norm layers (if the architecture has any) to inference
# behaviour; redundant but harmless if train.evaluate already calls eval().
model.eval()

# Plot losses. plt.show() renders the figure at this point. Figure.show() does
# not render immediately on the notebook's inline backend (the figure only
# appears when the cell ends) and can warn about a non-GUI backend.
fig, ax = plt.subplots(figsize=(10, 5))
ax = display.plot_losses(ax, hist, repr(params['train']['criterion']))
plt.show()

# Evaluate on the test split; train.evaluate returns a figure and its axes.
fig, axes = train.evaluate(device, model, data_test)
plt.show()