In [None]:
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import Tensor

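# Google Colab only: uncomment the three lines below to mount Google Drive
# and change into the project directory before the local imports run.
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/DEEP_LEARNING/KSPnet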

from model import Model
from train import train_model
from data import get_dataloaders
from criteria import TimeFrequencyLoss
from preprocess import frame, overlap_and_add
from utils import num_params, plot_waveform, play_audio

In [None]:
# Choose the compute device once; later cells pass `device` to the data
# loaders, the model and the training loop.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'device: {device}')

In [None]:
# Build the dataloaders from 'data/train.csv' and the 'data' directory.
# The result is indexed by split name below (dataloaders['train']).
dataloaders = get_dataloaders(
    'data/train.csv',
    'data',
    batch_size=1,
    max_audio_length=3,  # presumably seconds -- TODO confirm against data.py
    validation_split=0.1,
    device=device,
)

# Grab exactly one training batch to use as a listening/plotting example.
# Fail right here with a clear message if the loader is empty, instead of
# leaving test_noisy/test_clean undefined and hitting a NameError in a
# later cell.
first_batch = next(iter(dataloaders['train']), None)
if first_batch is None:
    raise RuntimeError(
        "dataloaders['train'] yielded no batches; "
        "check 'data/train.csv' and the 'data' directory")

# Each batch unpacks into three items; only the first element of the
# inputs and of the labels is kept (batch_size=1 above, so presumably the
# only element -- confirm against data.py).
inputs, labels, lengths = first_batch
test_noisy = inputs[0]
test_clean = labels[0]

In [None]:
# Inspect the example taken from the batch inputs: plot it, then play it.
# 16000 is passed as the second argument to both helpers (presumably the
# sample rate in Hz -- confirm against utils.py).
sample_rate = 16000
plot_waveform(test_noisy, sample_rate)
# Kept as the last expression so any value it returns is still displayed.
play_audio(test_noisy, sample_rate)

In [None]:
# Inspect the example taken from the batch labels: plot it, then play it.
# 16000 is passed as the second argument to both helpers (presumably the
# sample rate in Hz -- confirm against utils.py).
sample_rate = 16000
plot_waveform(test_clean, sample_rate)
# Kept as the last expression so any value it returns is still displayed.
play_audio(test_clean, sample_rate)

In [None]:
# Keyword hyperparameters forwarded unchanged to Model. Their meanings
# are defined in model.py and are not restated here.
model_config = dict(
    d_model=64,
    in_channels=1,
    out_channels=1,
    n_convs=1,
    kernel_size=(1, 3),
    stride=(1, 2),
    depth=5,
    k=2,
    n_intra=4,
    n_inter=4,
    n_heads=4,
    d_hid=1024,
    dropout=0.1,
    max_seq_len=1000,
)

# NOTE(review): the first positional argument is 512, and the training
# cell calls train_model with frame_length=512 -- presumably these must
# agree; confirm against model.py.
model = Model(512, **model_config).to(device)
print(num_params(model))

# Loss, optimizer and learning-rate schedule handed to train_model.
criterion = TimeFrequencyLoss(alpha=0.4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Multiplies the learning rate by `factor` once the monitored metric has
# stopped improving for more than `patience` scheduler steps. Which metric
# is stepped on is decided inside train_model (not visible here).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=5)

In [None]:
# The same file is given as both the load path and the save path.
# NOTE(review): load_checkpoint=True presumably expects this file to exist
# already -- confirm train_model tolerates a missing checkpoint on a
# fresh run.
checkpoint_path = 'checkpoint.pth.tar'

# Run training; rebinds `model` to the returned model and keeps the
# returned losses.
model, losses = train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    dataloaders,
    epochs=5,
    frame_length=512,
    frame_shift=256,
    save_checkpoint=True,
    load_checkpoint=True,
    save_checkpoint_filepath=checkpoint_path,
    load_checkpoint_filepath=checkpoint_path,
    device=device,
)