In [None]:
def decode_text(padded_text, text_lengths, batch_idx=0, symbol_table=None):
    """Decode one row of a padded batch of symbol IDs back into a string.

    Args:
        padded_text: Indexable batch of symbol-ID sequences (e.g. a 2-D tensor
            or a list of lists); row ``batch_idx`` is decoded.
        text_lengths: Per-row valid lengths; positions at or beyond
            ``text_lengths[batch_idx]`` are treated as padding and skipped.
        batch_idx: Which row of the batch to decode.
        symbol_table: Sequence mapping symbol IDs to strings. When omitted, the
            module-level ``symbols`` is used (presumably provided by
            ``from text import *`` — confirm).

    Returns:
        The concatenation of the symbols in the unpadded part of the row.
    """
    if symbol_table is None:
        symbol_table = symbols
    # int() accepts both Python ints and 0-d tensors; clamp so a (nonsensical)
    # negative length yields '' instead of slicing from the end.
    text_len = max(int(text_lengths[batch_idx]), 0)
    # Slice first so padding positions are never visited.
    valid_ids = padded_text[batch_idx][:text_len]
    return ''.join(symbol_table[int(ci)] for ci in valid_ids)

In [None]:
import os, argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.model import Model
from modules.loss import MDNLoss
import hparams
from text import *
from utils.utils import *
from utils.writer import get_writer
from torch.utils.tensorboard import SummaryWriter
import math

import matplotlib.pyplot as plt

# Global training setup. Statement order matters here:
# - CUDA_VISIBLE_DEVICES is read when CUDA initializes, so it must be set
#   before the first CUDA call (the `.cuda()` below) — presumably nothing
#   earlier in the session has touched CUDA; confirm when re-running cells.
# - RNGs are seeded before Model(hparams) is built so any random weight
#   initialization inside it is reproducible.
os.environ["CUDA_VISIBLE_DEVICES"]='0'
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)
    
# `collate_fn` is returned but not used anywhere else in this notebook.
train_loader, val_loader, collate_fn = prepare_dataloaders(hparams)
# DataParallel wrapper: the inner model's sub-modules are reached via
# `model.module` (see validate()).
model = nn.DataParallel(Model(hparams)).cuda()
criterion = MDNLoss()
writer = get_writer(hparams.output_directory, hparams.log_directory)
# Adam with betas (0.9, 0.98) and eps 1e-9; the learning rate set here is
# subsequently driven by lr_scheduling(optimizer, step) in the training loop.
optimizer = torch.optim.Adam(model.parameters(),
                             lr=hparams.lr,
                             betas=(0.9, 0.98),
                             eps=1e-09)


def viterbi(log_prob_matrix):
    """Find the highest-scoring monotonic alignment through a log-prob matrix.

    The path starts at row 0 on frame 0 and ends at row ``L-1`` on frame
    ``T-1``. From one frame to the next it either stays on the same row or
    advances by exactly one row. On ties during backtracking it prefers
    staying on the same row.

    Args:
        log_prob_matrix: 2-D tensor of shape (L, T); entry ``[l, t]`` is the
            log-score of assigning frame ``t`` to row ``l``.

    Returns:
        Tensor of shape (T, 1) with the same dtype and device as
        ``log_prob_matrix``, holding the row index chosen for each frame.
    """
    # Decoding needs no gradients; detaching keeps the DP table out of autograd.
    log_prob_matrix = log_prob_matrix.detach()
    L, T = log_prob_matrix.size()
    unreachable = -1e15  # large finite stand-in for log(0)
    log_beta = log_prob_matrix.new_full((L, T), unreachable)
    log_beta[0, 0] = log_prob_matrix[0, 0]

    # Forward pass: best score of any valid path ending at (row, frame).
    for t in range(1, T):
        stay = log_beta[:, t-1]
        # Row l may also be entered from row l-1; row 0 has no predecessor.
        advance = torch.cat([stay.new_full((1,), unreachable), stay[:-1]])
        log_beta[:, t] = torch.max(stay, advance) + log_prob_matrix[:, t]

    # Backtrack on the host: a single transfer instead of two .item() calls
    # (each a device sync on CUDA) per frame.
    beta = log_beta.cpu().tolist()
    row = L - 1
    path = [row]
    for t in range(T-1, 0, -1):
        if row > 0 and beta[row-1][t-1] > beta[row][t-1]:
            row -= 1
        path.append(row)
    path.reverse()

    return torch.tensor(path, dtype=log_prob_matrix.dtype,
                        device=log_prob_matrix.device).unsqueeze(1)


def _show_image(image):
    """Render a 2-D tensor with imshow in its own 15x4 figure, origin at bottom-left."""
    plt.figure(figsize=(15, 4))
    plt.imshow(image.detach().cpu(), aspect='auto', origin='lower')
    plt.show()


def validate(model, criterion, val_loader, iteration, writer):
    """Report the mean MDN loss over ``val_loader`` and visualize one alignment.

    Every validation batch runs through the encoder path
    (``Prenet`` -> ``FFT_lower`` -> ``get_mu_sigma``). The ``criterion`` loss
    is averaged, weighted by batch size, and printed. The first utterance of
    the last batch is then Viterbi-decoded and three plots are shown: its
    alignment, its target mel, and the mel rebuilt from the aligned
    ``mu_sigma`` rows. Finally its text is printed. The model is put back in
    train mode on exit, even if an exception is raised.

    Args:
        model: ``nn.DataParallel``-wrapped model; sub-modules are reached via
            ``model.module``.
        criterion: Callable returning
            ``(loss, log_prob_matrix, log_alpha, alpha_last)``.
        val_loader: Iterable of
            ``(text_padded, text_lengths, mel_padded, mel_lengths)`` batches.
        iteration: Current micro-batch counter, shown in the printed report.
        writer: Logger; currently unused (logging calls are disabled below).
    """
    model.eval()
    try:
        n_data, val_loss = 0, 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch_size = len(batch[0])
                text_padded, text_lengths, mel_padded, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]
                # Batch-wide min-max scaling to [0, 1], same as the training loop.
                mel_shifted = mel_padded - torch.min(mel_padded)
                mel_padded = mel_shifted/torch.max(mel_shifted)

                encoder_input = model.module.Prenet(text_padded)
                hidden_states, _ = model.module.FFT_lower(encoder_input, text_lengths)
                mu_sigma = model.module.get_mu_sigma(hidden_states)

                mdn_loss, log_prob_matrix, _, _ = criterion(
                    mu_sigma, mel_padded, text_lengths, mel_lengths)
                n_data += batch_size
                val_loss += mdn_loss.item() * batch_size

            if n_data == 0:
                print("Validation skipped: val_loader yielded no data.")
                return

            val_loss /= n_data
            print(f"Iteration: {iteration} / Validation Loss: {val_loss}")

            # Visualize the first utterance of the last batch, cropped to its
            # true lengths. Viterbi always ends at the last row/frame of the
            # matrix it is given, so passing the padded matrix would force the
            # path into padding whenever this utterance is not the longest.
            # NOTE(review): log_prob_matrix is indexed as (batch, text, frame),
            # consistent with the matmul against mu_sigma[0] below — confirm
            # against MDNLoss.
            text_len = int(text_lengths[0])
            mel_len = int(mel_lengths[0])
            probable_path = viterbi(log_prob_matrix[0, :text_len, :mel_len])

            # One-hot (frames x text_len) alignment: row t marks frame t's token.
            token_ids = torch.arange(text_len, device=probable_path.device,
                                     dtype=probable_path.dtype)
            path_oh = (token_ids.unsqueeze(0) == probable_path).to(mu_sigma.dtype)
            # Each frame copies its aligned token's first n_mel_channels values
            # (presumably the mu half of mu_sigma).
            mel_out = torch.matmul(path_oh, mu_sigma[0, :text_len, :hparams.n_mel_channels])

            _show_image(path_oh.t())
            _show_image(mel_padded[0])
            _show_image(mel_out.t())

            print(decode_text(text_padded, text_lengths, batch_idx=0))

            # TensorBoard logging is disabled. If re-enabled, log the averaged
            # val_loss rather than the last batch's mdn_loss:
            # step = iteration//hparams.accumulation
            # writer.add_losses(val_loss, step, 'Validation')
            # writer.add_specs(mel_padded.detach().cpu(), mel_out.detach().cpu(),
            #                  mel_lengths.detach().cpu(), step, 'Validation')
            # writer.add_alignments(probable_path.detach().cpu(),
            #                       text_lengths.detach().cpu(),
            #                       mel_lengths.detach().cpu(), step, 'Validation')
    finally:
        model.train()

In [None]:
# Print every hparams attribute except dunder names, as "name: value" with the
# name right-aligned in a 20-character column (dir() order).
hparam_list = [
    (attr_name, getattr(hparams, attr_name))
    for attr_name in dir(hparams)
    if attr_name[:2] != "__"
]

for attr_name, attr_value in hparam_list:
    print(f'{attr_name:>20}: {attr_value}')

In [None]:
# Main training loop.
# `iteration` counts micro-batches. Gradients from `hparams.accumulation`
# consecutive micro-batches are summed before each optimizer step, so the run
# lasts train_steps * accumulation micro-batches.
iteration, loss = 0, 0
model.train()

loss_list = list()
total_iters = hparams.train_steps*hparams.accumulation

print("Training Start!!!")
while iteration < total_iters:
    n_batches = 0
    for batch in train_loader:
        n_batches += 1
        text_padded, text_lengths, mel_padded, mel_lengths = [
            x.cuda() for x in batch
        ]
        # Batch-wide min-max scaling to [0, 1]; validate() applies the same transform.
        mel_shifted = mel_padded - torch.min(mel_padded)
        mel_padded = mel_shifted/torch.max(mel_shifted)

        mdn_loss = model(text_padded,
                         mel_padded,
                         None,
                         text_lengths,
                         mel_lengths,
                         criterion,
                         stage=2)
        # .mean() presumably reduces the per-GPU losses gathered by
        # nn.DataParallel (a no-op for a scalar). Dividing by `accumulation`
        # makes the gradients summed over one step equal the micro-batch average.
        sub_loss = mdn_loss.mean()/hparams.accumulation
        sub_loss.backward()
        loss = loss+sub_loss.item()

        iteration += 1

        if iteration%hparams.accumulation == 0:
            opt_step = iteration//hparams.accumulation
            # `loss` now holds the mean loss over this step's micro-batches (the
            # value logged below). Print it rather than the last micro-batch's
            # 1/accumulation-scaled value.
            print(f"Iteration: {iteration} / Loss: {loss}")
            lr_scheduling(optimizer, opt_step)
            nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
            optimizer.step()
            model.zero_grad()
            writer.add_losses(loss, opt_step, 'Train')

            loss_list.append(loss)

            loss=0

        if iteration%(hparams.iters_per_validation*hparams.accumulation)==0:

            plt.figure(figsize=(20, 8))
            plt.title(f'MDN Loss Graph: iteration #{iteration}')
            plt.plot(loss_list)
            plt.show()

            validate(model, criterion, val_loader, iteration, writer)

        if iteration%(hparams.iters_per_checkpoint*hparams.accumulation)==0:

            save_checkpoint(model,
                            optimizer,
                            hparams.lr,
                            iteration//hparams.accumulation,
                            filepath=f'{hparams.output_directory}/{hparams.log_directory}')

        if iteration==total_iters:
            break

    if n_batches == 0:
        # An empty loader never advances `iteration`, so the while loop would spin forever.
        raise RuntimeError("train_loader yielded no batches; training cannot progress.")