# Set up model and hyperparameters

In [1]:
from models.WaveNetVAE.WaveVae import WaveNetVAE
from models.WaveNetVAE.WVData import WVDataset
import torch
from torch.utils.data import DataLoader

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

learning_rate = 0.000001
batchsize = 4
device='cuda'
# NOTE(review): presumably (n_mfcc, ...): 40 matches the MFCC channel count
# printed by the smoke test; the meaning of 112 is not visible here — confirm.
input_size = (40, 112)
# Eight 2x upsampling stages (256x in total). Presumably this maps the 16
# downsampled frames seen in the smoke test back to 16 * 256 == 4096 samples.
upsamples = [2, 2, 2, 2, 2, 2, 2, 2]
zsize = 32  # presumably the latent size of the VAE — confirm in WaveNetVAE

# Data locations and settings shared by the training and validation datasets.
train_audio_path = "../ConvDenoiser/testdatawav"
# NOTE(review): validation currently reads the same directory as training, so
# validation losses will not measure generalization. Point this at held-out data.
val_audio_path = "../ConvDenoiser/testdatawav"
dataset_kwargs = dict(length = 4096,
                      sample_rate = 24000,
                      hop_length = 128)

WaveVAE = WaveNetVAE(input_size,
                     num_hiddens = 768,
                     upsamples = upsamples,
                     zsize = zsize)

WaveVAE.to(device)

VAEDataset = WVDataset(audio_path = train_audio_path, **dataset_kwargs)
val_VAEDataset = WVDataset(audio_path = val_audio_path, **dataset_kwargs)

VAEDataloader = DataLoader(VAEDataset,
                           batch_size = batchsize,
                           shuffle = True)

# The validation loader must wrap val_VAEDataset, not the training dataset.
val_VAEDataloader = DataLoader(val_VAEDataset,
                               batch_size = batchsize,
                               shuffle = False)

optimizer = torch.optim.AdamW(WaveVAE.parameters(), lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()

WaveNet Receptive Field:  4093


Loading and preprocessing files to dataset.:   0%|          | 0/33 [00:00<?, ?it/s]

4096


  raw, _ = librosa.load(filename, sampling_rate, res_type=res_type)
  raw, _ = librosa.load(filename, sampling_rate, res_type=res_type)
Loading and preprocessing files to dataset.: 100%|██████████| 33/33 [00:12<00:00,  2.63it/s]


# Smoke test: pass one batch through the model and loss function

In [7]:
onehot, mfcc, target = next(iter(VAEDataloader))
onehot = onehot.to(device)
mfcc = mfcc.to(device)
target = target.to(device)

print("Trying tensors with sizes:")
print("Onehot size: ", onehot.size(), "| MFCC size: ", mfcc.size(), "| Target size: ", target.size())

# Forward-only check: no_grad avoids building an autograd graph for a full
# batch that is never backpropagated (it would otherwise stay referenced by
# `output`/`loss` in the kernel and hold GPU memory during training).
with torch.no_grad():
    # The model returns (output, mean, variance), as unpacked in the training
    # loop; indexing the returned tuple directly would raise a TypeError.
    output, mean, variance = WaveVAE(onehot, mfcc, True)
    print("Tensors passed through model succesfully")
    loss = loss_fn(output[:, :, -1], target)
print("Loss function output: ", loss)

Onehot size:  torch.Size([4, 256, 4096]) | MFCC size:  torch.Size([4, 40, 33]) | Target size:  torch.int64
Before downsample:  torch.Size([4, 768, 33])
After downsample:  torch.Size([4, 768, 16])
torch.Size([4, 256, 16])
torch.Size([4, 256, 4096]) torch.Size([4, 256, 4096])
torch.float32


# Start training

In [None]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
# Run length and logging cadence.
epoch_amount = 100
logs_per_epoch = 5

# KL weight schedule: kl_mult is the current weight on the KL term and
# kl_anneal is the increment that anneal_kl adds to it.
kl_mult = 0.0
kl_anneal = 0.1

# Global counters used as the x-axis for TensorBoard scalars.
total_step = 0
logstep = 0

torch.cuda.empty_cache()
writer = SummaryWriter()

def calculate_loss(output, target, mean, variance, kl_term):
    """Combine the reconstruction loss with a weighted KL-divergence term.

    Args:
        output: Model output. Only the last position along the final axis
            (``output[:, :, -1]``) is scored against ``target``.
        target: Targets passed to the notebook-level ``loss_fn``.
        mean: Latent means produced by the model.
        variance: Used as a *log*-variance by the KL formula (it is
            exponentiated), despite its name.
        kl_term: Weight applied to the KL loss.

    Returns:
        Tuple ``(total_loss, reconstruction_loss, kl_loss)``, where
        ``total_loss = reconstruction_loss + kl_loss * kl_term``.
    """
    last_position = output[:, :, -1]
    rec = loss_fn(last_position, target)

    # Closed-form KL(N(mean, exp(variance)) || N(0, 1)), averaged over all elements.
    kl_per_element = 1 + variance - mean.pow(2) - variance.exp()
    kl = -0.5 * torch.mean(kl_per_element)

    combined = rec + kl * kl_term
    return combined, rec, kl

def anneal_kl(kl_term, kl_annealing, max_term=None):
    """Advance a linear KL-weight schedule by one increment.

    Args:
        kl_term: Current weight on the KL loss.
        kl_annealing: Amount added to the weight per call.
        max_term: Optional ceiling for the weight. The default ``None``
            applies no ceiling, matching the original behaviour; passing
            e.g. ``1.0`` stops the weight from growing past that value.

    Returns:
        The updated KL weight.
    """
    kl_term += kl_annealing
    if max_term is not None:
        kl_term = min(kl_term, max_term)

    return kl_term

# Log roughly `logs_per_epoch` times per epoch. max(1, ...) avoids a
# ZeroDivisionError when the loader has fewer batches than logs_per_epoch.
log_interval = max(1, len(VAEDataloader) // logs_per_epoch)

for epoch in range(epoch_amount):
    WaveVAE.train(True)
    # Running sums of [total, reconstruction, KL] losses for this epoch.
    total_epoch_loss = [0, 0, 0]
    step = 1

    with tqdm(enumerate(VAEDataloader),total=len(VAEDataloader),desc=f"Training. Epoch: {epoch}. Loss for step {step}: n.v.t.") as t:
        for batch_idx, (onehot_input, mfcc_input, target) in t:
            optimizer.zero_grad(set_to_none=True)

            onehot_input = onehot_input.to(device)
            mfcc_input = mfcc_input.to(device)
            target = target.to(device)

            # The second argument is the MFCC tensor, matching the smoke-test
            # call WaveVAE(onehot, mfcc, True). Passing `target` here left
            # mfcc_input unused.
            output, mean, variance = WaveVAE(onehot_input, mfcc_input, True)

            # NOTE(review): kl_mult starts at 0.0 and anneal_kl is not called in
            # the code visible here, so the KL term contributes nothing to
            # real_loss unless kl_mult is updated elsewhere — confirm.
            real_loss, rec_loss, kl_loss = calculate_loss(output, target, mean, variance, kl_mult)
            real_loss.backward()
            optimizer.step()

            # Read each scalar once (.item() synchronizes with the device).
            real_val, rec_val, kl_val = real_loss.item(), rec_loss.item(), kl_loss.item()

            # Accumulate total, reconstruction and KL losses separately so each
            # part of the optimisation can be inspected on its own.
            total_epoch_loss = [
                running + current
                for running, current in zip(total_epoch_loss, (real_val, rec_val, kl_val))
                ]

            t.set_description(f"Training. Epoch: {epoch}. Rec/real loss for step {step}: {round(rec_val, 2)}/{round(real_val, 2)}.")
            writer.add_scalar('Train step loss:', real_val, total_step)
            step += 1
            total_step += 1

            # Number of batches summed into total_epoch_loss so far; the
            # running averages must divide by this, not by `step`, which is
            # already one ahead after the increment above.
            batches_done = batch_idx + 1
            if batches_done == 1 or batches_done % log_interval == 0:

                writer.add_scalars('Train loss', {
                                        'Real loss': total_epoch_loss[0] / batches_done,
                                        'Reconstruction loss': total_epoch_loss[1] / batches_done,
                                        'Kl loss': total_epoch_loss[2] / batches_done
                                    }, logstep)

                logstep += 1
