# Set up model and hyperparameters

In [1]:
from models.WaveNetVAE.WaveVae import WaveNetVAE
from models.WaveNetVAE.WVData import WVDataset
import torch
from torch.utils.data import DataLoader

"""
Hyperparameters
"""

learning_rate = 0.0001
batchsize = 4
# Fall back to the CPU when no CUDA device is available so the notebook can
# still be run top to bottom on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
input_size = (40, 112)
upsamples = [2, 2, 2, 2, 2, 2, 2, 2]
zsize = 32

# Preprocessing settings shared by the training and validation datasets.
# Keeping them in one mapping guarantees both splits are built identically.
dataset_kwargs = dict(length = 4096,
                      sample_rate = 24000,
                      hop_length = 128)

# Build the model and move its parameters to the selected device.
WaveVAE = WaveNetVAE(input_size,
                     num_hiddens = 768,
                     upsamples = upsamples,
                     zsize = zsize)

WaveVAE.to(device)

# NOTE(review): the training set is read from a folder named "testdatawav";
# confirm this is the intended training data and not a held-out test split.
VAEDataset = WVDataset(audio_path = "../ConvDenoiser/testdatawav",
                       **dataset_kwargs)

val_VAEDataset = WVDataset(audio_path = "../ConvDenoiser/valdatawav",
                           **dataset_kwargs)

# Training batches are shuffled; validation batches keep dataset order.
VAEDataloader = DataLoader(VAEDataset,
                           batch_size = batchsize,
                           shuffle = True)

val_VAEDataloader = DataLoader(val_VAEDataset,
                               batch_size = batchsize,
                               shuffle = False)

WaveNet Receptive Field:  4093


Loading and preprocessing files to dataset.:   0%|          | 0/29 [00:00<?, ?it/s]

4096


Loading and preprocessing files to dataset.: 100%|██████████| 29/29 [00:19<00:00,  1.51it/s]


4096


Loading and preprocessing files to dataset.: 100%|██████████| 4/4 [00:04<00:00,  1.24s/it]


# Smoke test: pass one batch through the model

In [7]:
# Pull a single training batch and move each tensor to the model's device.
onehot, mfcc, target = next(iter(VAEDataloader))
onehot = onehot.to(device)
mfcc = mfcc.to(device)
target = target.to(device)

print("Trying tensors with sizes:")
print("Onehot size: ", onehot.size(), "| MFCC size: ", mfcc.size(), "| Target size: ", target.size())

# NOTE(review): `loss_fn` was not defined in any cell of this notebook, so this
# cell failed with NameError on a fresh kernel. Cross-entropy is assumed here
# (class scores compared against integer targets) -- confirm it matches the
# loss used inside models.WaveNetVAE.train.
loss_fn = torch.nn.CrossEntropyLoss()

# This is only a shape/sanity check, so skip autograd bookkeeping: no graph is
# kept alive on the device before the real training run starts.
with torch.no_grad():
    output = WaveVAE(onehot, mfcc, True)
    print("Tensors passed through model succesfully")
    # Only the final position along the last axis of the output is scored.
    loss = loss_fn(output[:, :, -1], target)
print("Loss function output: ", loss)

Onehot size:  torch.Size([4, 256, 4096]) | MFCC size:  torch.Size([4, 40, 33]) | Target size:  torch.int64
Before downsample:  torch.Size([4, 768, 33])
After downsample:  torch.Size([4, 768, 16])
torch.Size([4, 256, 16])
torch.Size([4, 256, 4096]) torch.Size([4, 256, 4096])
torch.float32


# Start training

In [2]:
from models.WaveNetVAE.train import train
from torch.utils.tensorboard import SummaryWriter
import warnings
# NOTE: this suppresses every Python warning for the rest of the kernel
# session, not just those raised during training.
warnings.filterwarnings("ignore")
writer = SummaryWriter()

try:
    train(WaveVAE, VAEDataloader, val_VAEDataloader,
          writer = writer,
          learning_rate = learning_rate,
          epoch_amount = 100,
          logs_per_epoch = 5,
          kl_anneal = 0.01,
          max_kl = 0.5,
          device = device)
finally:
    # Always flush pending TensorBoard events and release the event file,
    # including when the run is stopped early with a KeyboardInterrupt.
    writer.close()


  return F.conv1d(input, weight, bias, self.stride,
Validating. Rec loss: 5.61.: 100%|██████████| 58/58 [00:44<00:00,  1.30it/s]02:23<09:09,  2.08s/it]
Validating. Rec loss: 5.72.: 100%|██████████| 58/58 [00:41<00:00,  1.39it/s] [05:26<06:48,  2.05s/it]
Validating. Rec loss: 5.67.: 100%|██████████| 58/58 [00:39<00:00,  1.48it/s] [08:27<04:30,  2.02s/it]
Validating. Rec loss: 5.47.: 100%|██████████| 58/58 [00:40<00:00,  1.43it/s] [11:33<02:40,  2.32s/it]
Validating. Rec loss: 5.51.: 100%|██████████| 58/58 [00:39<00:00,  1.48it/s] [14:31<00:08,  2.09s/it]
Validating. Rec/real loss for step 327: 5.74/5.74.: 100%|██████████| 327/327 [15:16<00:00,  2.80s/it]
Validating. Rec loss: 5.52.:  24%|██▍       | 14/58 [00:14<00:46,  1.05s/it]02:21<10:40,  2.43s/it]
Validating. Rec/real loss for step 64: 5.43/5.43.:  19%|█▉        | 63/327 [02:36<10:55,  2.48s/it]


KeyboardInterrupt: 