In [1]:
import os, argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.model import Model
from modules.loss import MDNLoss
import hparams
from text import *
from utils.utils import *
from utils.writer import get_writer
from torch.utils.tensorboard import SummaryWriter

os.environ["CUDA_VISIBLE_DEVICES"]='0'
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)
    
train_loader, val_loader, collate_fn = prepare_dataloaders(hparams)
model = Model(hparams).cuda()
criterion = MDNLoss()
optimizer = torch.optim.Adam(model.parameters(),
                             lr=hparams.lr,
                             betas=(0.9, 0.98),
                             eps=1e-09)

In [2]:
iteration, loss = 0, 0
model.train()
print("Training Start!!!")
for i, batch in enumerate(train_loader):
    text_padded, text_lengths, mel_padded, mel_lengths = [
        x.cuda() for x in batch
    ]
    mel_padded = (mel_padded - torch.min(mel_padded))/torch.max((mel_padded - torch.min(mel_padded)))

    mu, sigma = model(text_padded,
                     mel_padded,
                     None,
                     text_lengths,
                     mel_lengths,
                     criterion,
                     stage=0)
    break
    sub_loss = mdn_loss.mean()/hparams.accumulation
    sub_loss.backward()
    loss = loss+sub_loss.item()

    iteration += 1
    if iteration%hparams.accumulation == 0:
        lr_scheduling(optimizer, iteration//hparams.accumulation)
        nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
        optimizer.step()
        model.zero_grad()
        loss=0

Training Start!!!


  ids = lengths.new_tensor(torch.arange(0, max_len))


In [3]:
mu_sigma = torch.cat([mu.squeeze(2), sigma.squeeze(2)], dim=-1)
B, L, F = mu_sigma.size()
T = mel_padded.size(2)

x = mel_padded.transpose(1,2).unsqueeze(1) # B, 1, T, F
mu = mu_sigma[:, :, :hparams.n_mel_channels].unsqueeze(2) # B, L, 1, F
sigma = mu_sigma[:, :, hparams.n_mel_channels:].unsqueeze(2) # B, L, 1, F

In [4]:
exponential = -0.5*torch.sum((x-mu)*(x-mu)/sigma**2, dim=-1) # B, L, T
coef = (2*3.14)**(hparams.n_mel_channels/2) * torch.prod(sigma, dim=-1)**0.5 # B, L, 1

prob_matrix = torch.exp(exponential) / coef # B, L, T

In [None]:
(2*3.14)**(hparams.n_mel_channels/2)

In [7]:
prob_matrix

tensor([[[1.7878e-27, 1.8860e-25, 7.1307e-25,  ..., 5.9631e-32,
          3.4582e-32, 5.7192e-32],
         [2.2234e-26, 2.3228e-25, 3.0743e-25,  ..., 4.6142e-31,
          5.9364e-31, 5.4127e-31],
         [1.9233e-25, 2.6044e-25, 4.1108e-25,  ..., 2.6783e-28,
          1.6611e-28, 1.9354e-28],
         ...,
         [4.5936e-30, 1.0362e-28, 5.3460e-28,  ..., 3.5198e-33,
          1.5913e-33, 1.0099e-32],
         [5.0251e-30, 1.5029e-27, 1.8427e-26,  ..., 6.8919e-37,
          1.6213e-36, 1.3485e-36],
         [1.2304e-26, 3.4397e-25, 2.1573e-25,  ..., 3.4750e-31,
          4.7190e-31, 1.0223e-30]],

        [[1.4813e-28, 1.5204e-24, 1.8643e-25,  ..., 2.5170e-40,
          2.5170e-40, 2.5170e-40],
         [8.8440e-26, 2.8244e-22, 4.9172e-23,  ..., 4.7224e-43,
          4.7224e-43, 4.7224e-43],
         [5.4583e-25, 2.1930e-22, 1.5412e-23,  ..., 2.0067e-42,
          2.0067e-42, 2.0067e-42],
         ...,
         [1.7104e-26, 4.5980e-24, 8.1279e-25,  ..., 4.7171e-39,
          4.717

#    
#    
#    
#    
#    
#    
#    
#    

In [None]:
def validate(model, criterion, val_loader, iteration, writer):
    model.eval()
    with torch.no_grad():
        n_data, val_loss = 0, 0
        for i, batch in enumerate(val_loader):
            n_data += len(batch[0])
            text_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [
                x.cuda() for x in batch
            ]
            
            mel_out, mel_out_post,\
            enc_alignments, dec_alignments, enc_dec_alignments, gate_out = model.module.outputs(text_padded,
                                                                                                mel_padded,
                                                                                                text_lengths,
                                                                                                mel_lengths)
        
            mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out),
                                                       (mel_padded, gate_padded),
                                                       (enc_dec_alignments, text_lengths, mel_lengths))
            
            loss = torch.mean(mel_loss+bce_loss+guide_loss)
            val_loss += loss.item() * len(batch[0])

        val_loss /= n_data

    writer.add_losses(mel_loss.item(),
                      bce_loss.item(),
                      guide_loss.item(),
                      iteration//hparams.accumulation, 'Validation')
    
    writer.add_specs(mel_padded.detach().cpu(),
                     mel_out.detach().cpu(),
                     mel_out_post.detach().cpu(),
                     mel_lengths.detach().cpu(),
                     iteration//hparams.accumulation, 'Validation')
    
    writer.add_alignments(enc_alignments.detach().cpu(),
                          dec_alignments.detach().cpu(),
                          enc_dec_alignments.detach().cpu(),
                          text_padded.detach().cpu(),
                          mel_lengths.detach().cpu(),
                          text_lengths.detach().cpu(),
                          iteration//hparams.accumulation, 'Validation')
    
    writer.add_gates(gate_out.detach().cpu(),
                    iteration//hparams.accumulation, 'Validation')
    model.train()