In [1]:
import os, argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.model import Model
from modules.loss import MDNLoss
import hparams
from text import *
from utils.utils import *
from utils.writer import get_writer
from torch.utils.tensorboard import SummaryWriter
import math

# Limit which GPUs this process can see.
# NOTE(review): the value has a space after the comma — confirm the CUDA
# runtime parses it as two devices rather than stopping at the first.
os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1'

# Seed the CPU and CUDA random generators from the shared hyper-parameter.
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)

# Data pipeline; `prepare_dataloaders` is an imported helper that returns
# the train loader, the validation loader and the collate function.
train_loader, val_loader, collate_fn = prepare_dataloaders(hparams)

# Wrap the network in DataParallel and move it onto the GPU(s).
model = nn.DataParallel(Model(hparams))
model = model.cuda()

# Loss used by the training cell, and the Adam optimizer over all parameters.
criterion = MDNLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=hparams.lr,
    betas=(0.9, 0.98),
    eps=1e-09,
)

In [2]:
# Training loop (model called with stage=0) using gradient accumulation:
# gradients from `hparams.accumulation` consecutive batches are summed
# before each optimizer step.
iteration, loss = 0, 0
model.train()
print("Training Start!!!")
for i, batch in enumerate(train_loader):
    # `reorder_batch` is an imported helper; presumably it arranges the
    # samples for the DataParallel split across `n_gpus` — TODO confirm.
    text_padded, text_lengths, mel_padded, mel_lengths = (
        reorder_batch(x, hparams.n_gpus).cuda() for x in batch
    )

    # Min-max scale the mel batch using the global min/max of this batch.
    mel_shifted = mel_padded - torch.min(mel_padded)
    mel_padded = mel_shifted / torch.max(mel_shifted)

    mdn_loss = model(
        text_padded, mel_padded, None,
        text_lengths, mel_lengths,
        criterion, stage=0,
    )

    # Scale each sub-batch so the accumulated gradient is an average.
    sub_loss = mdn_loss.mean() / hparams.accumulation
    sub_loss.backward()
    loss += sub_loss.item()

    # NOTE(review): `loss` is the running sum inside the current accumulation
    # window, so when accumulation > 1 the printed value depends on where
    # batch `i` falls within that window.
    if i % 10 == 0:
        print(f'Loss ({i}): {loss}')

    iteration += 1
    if iteration % hparams.accumulation != 0:
        continue  # keep accumulating gradients

    # Accumulation window complete: update LR, clip gradients, step, reset.
    lr_scheduling(optimizer, iteration // hparams.accumulation)
    nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
    optimizer.step()
    model.zero_grad()
    loss = 0

Training Start!!!


  ids = lengths.new_tensor(torch.arange(0, max_len))


Loss (0): 7083.978515625
Loss (10): 7006.9267578125
Loss (20): 6540.830078125
Loss (30): 6789.6240234375
Loss (40): 6518.220703125
Loss (50): 5833.9619140625
Loss (60): 5603.52734375
Loss (70): 5355.4697265625
Loss (80): 5085.9853515625
Loss (90): 4981.39453125
Loss (100): 4892.7568359375
Loss (110): 4700.94921875
Loss (120): 4823.7666015625
Loss (130): 4404.1572265625
Loss (140): 4422.609375
Loss (150): 4080.638427734375


KeyboardInterrupt: 

#    
#    
#    
#    
#    
#    
#    
#    

In [None]:
def validate(model, criterion, val_loader, iteration, writer):
    """Evaluate ``model`` on ``val_loader`` and log the results via ``writer``.

    The model is switched to eval mode for the duration of the call and is
    always switched back to train mode on exit, even if an exception is
    raised part-way through.

    Args:
        model: Wrapped model exposing ``model.module.outputs(...)`` as well
            as ``eval()`` / ``train()``.
        criterion: Callable returning ``(mel_loss, bce_loss, guide_loss)``.
        val_loader: Iterable of batches; each batch holds five items
            (text, text lengths, mel, mel lengths, gate), each with ``.cuda()``.
        iteration: Raw iteration counter; divided by
            ``hparams.accumulation`` to obtain the logging step.
        writer: Logger providing ``add_losses``, ``add_specs``,
            ``add_alignments`` and ``add_gates``.

    Returns:
        The sample-weighted mean validation loss as a float, or ``None`` when
        ``val_loader`` yielded no samples (in which case nothing is logged).

    NOTE(review): this function expects five-item batches and a criterion
    returning three losses, whereas the training cell in this notebook
    unpacks four items and uses ``MDNLoss`` — confirm it matches the current
    model and loaders before calling it.
    """
    model.eval()
    try:
        n_data, val_loss = 0, 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch_size = len(batch[0])
                n_data += batch_size
                text_padded, text_lengths, mel_padded, mel_lengths, gate_padded = [
                    x.cuda() for x in batch
                ]

                (mel_out, mel_out_post,
                 enc_alignments, dec_alignments,
                 enc_dec_alignments, gate_out) = model.module.outputs(text_padded,
                                                                      mel_padded,
                                                                      text_lengths,
                                                                      mel_lengths)

                mel_loss, bce_loss, guide_loss = criterion((mel_out, mel_out_post, gate_out),
                                                           (mel_padded, gate_padded),
                                                           (enc_dec_alignments, text_lengths, mel_lengths))

                loss = torch.mean(mel_loss + bce_loss + guide_loss)
                # Weight by batch size so the final value is a per-sample mean.
                val_loss += loss.item() * batch_size

        if n_data == 0:
            # Empty loader: there is nothing to average and no tensors to log.
            return None

        val_loss /= n_data

        step = iteration // hparams.accumulation

        # NOTE: everything logged below comes from the LAST batch only; the
        # aggregated `val_loss` is returned to the caller instead.
        writer.add_losses(mel_loss.item(),
                          bce_loss.item(),
                          guide_loss.item(),
                          step, 'Validation')

        writer.add_specs(mel_padded.detach().cpu(),
                         mel_out.detach().cpu(),
                         mel_out_post.detach().cpu(),
                         mel_lengths.detach().cpu(),
                         step, 'Validation')

        writer.add_alignments(enc_alignments.detach().cpu(),
                              dec_alignments.detach().cpu(),
                              enc_dec_alignments.detach().cpu(),
                              text_padded.detach().cpu(),
                              mel_lengths.detach().cpu(),
                              text_lengths.detach().cpu(),
                              step, 'Validation')

        writer.add_gates(gate_out.detach().cpu(),
                         step, 'Validation')

        return val_loss
    finally:
        # Always hand the model back in training mode.
        model.train()