In [1]:
import os, argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.model import Model
from modules.loss import MDNLoss
import hparams
from text import *
from utils.utils import *
from utils.writer import get_writer
from torch.utils.tensorboard import SummaryWriter
import math
import matplotlib.pyplot as plt


# --- Stage-1 setup: device selection, seeding, data, warm start -----------

# CUDA reads CUDA_VISIBLE_DEVICES only when it is first initialised, so this
# must run before the first `.cuda()` call below. Use the plain
# comma-separated form (no spaces).
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)

train_loader, val_loader, collate_fn = prepare_dataloaders(hparams, stage=1)

# Warm-start from the stage-0 checkpoint, resolved under the same
# output/log directory that the stage-1 checkpoints are saved to.
checkpoint_path = f"{hparams.output_directory}/{hparams.log_directory}/stage0/checkpoint_40000"

# The checkpoint is presumably saved from an nn.DataParallel-wrapped model
# (as the stage-1 checkpoints in this notebook are), whose parameter names
# carry a "module." prefix. Strip the prefix only where it is present:
# blindly dropping the first 7 characters would corrupt unprefixed keys.
# Tensors are loaded onto the CPU; load_state_dict copies them into the
# model's CUDA parameters.
data_parallel_prefix = 'module.'
state_dict = {}
for k, v in torch.load(checkpoint_path, map_location='cpu')['state_dict'].items():
    if k.startswith(data_parallel_prefix):
        k = k[len(data_parallel_prefix):]
    state_dict[k] = v

model = Model(hparams).cuda()
model.load_state_dict(state_dict)
# Wrap for multi-GPU training; the bare network is reachable as `model.module`.
model = nn.DataParallel(model).cuda()

criterion = MDNLoss()
writer = get_writer(hparams.output_directory, f'{hparams.log_directory}/stage1')
optimizer = torch.optim.Adam(model.parameters(),
                             lr=hparams.lr,
                             betas=(0.9, 0.98),
                             eps=1e-09)


def validate(model, criterion, val_loader, iteration, writer):
    """Compute and log the stage-1 validation L1 loss on mel-spectrograms.

    For every validation batch, runs ``model.module.Prenet`` and
    ``model.module.FFT_lower`` on the text, expands the hidden states to mel
    frames with ``model.module.get_melspec`` using the batch's alignment, and
    accumulates the L1 loss against the target mel, weighted by batch size.

    Args:
        model: ``nn.DataParallel``-wrapped model; its ``.module`` must expose
            ``Prenet``, ``FFT_lower`` and ``get_melspec``.
        criterion: Unused (kept for call-site compatibility); validation uses
            a plain ``nn.L1Loss``.
        val_loader: Iterable of batches ordered as ``(text_padded, mel_padded,
            align_padded, text_lengths, mel_lengths)``.
        iteration: Micro-batch iteration counter. Values are logged at
            ``iteration // hparams.accumulation``, the same step axis used for
            the training loss.
        writer: Logger providing ``add_scalar`` and ``add_specs``.

    The model is always put back into training mode, even if validation
    raises. If ``val_loader`` yields no data, a warning is emitted and nothing
    is logged (instead of failing with ZeroDivisionError).
    """
    import warnings

    l1_loss = nn.L1Loss()
    step = iteration // hparams.accumulation

    model.eval()
    try:
        with torch.no_grad():
            n_data, val_loss = 0, 0.0
            for batch in val_loader:
                batch_size = len(batch[0])
                text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
                    reorder_batch(x, hparams.n_gpus).cuda() for x in batch
                ]

                encoder_input = model.module.Prenet(text_padded)
                hidden_states, _ = model.module.FFT_lower(encoder_input, text_lengths)
                mel_out = model.module.get_melspec(hidden_states, align_padded, mel_lengths)
                # Weight by batch size so the result is a per-example mean.
                val_loss += l1_loss(mel_out, mel_padded).item() * batch_size
                n_data += batch_size

        if n_data == 0:
            warnings.warn("validate(): val_loader yielded no data; skipping validation logging.")
            return

        writer.add_scalar('Validation loss', val_loss / n_data, step)
        # Visualise the first example of the last validation batch.
        writer.add_specs(mel_padded[0].detach().cpu(),
                         mel_out[0].detach().cpu(),
                         step, 'Validation')
    finally:
        model.train()

In [2]:
# Stage-1 training loop. `iteration` counts micro-batches; every
# `hparams.accumulation` micro-batches form one optimizer step, and
# `iteration // accum_steps` is the step index used for the LR schedule,
# logging, validation and checkpoint names. Training stops once
# `hparams.train_steps[1]` optimizer steps have been taken.
accum_steps = hparams.accumulation
max_iteration = hparams.train_steps[1] * accum_steps
validation_period = hparams.iters_per_validation * accum_steps
checkpoint_period = hparams.iters_per_checkpoint * accum_steps

iteration, loss = 0, 0
model.train()

finished = False
while not finished:
    # One pass over the loader per epoch; keep cycling until the target is hit.
    for batch in train_loader:
        text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
            reorder_batch(x, hparams.n_gpus).cuda() for x in batch
        ]

        fft_loss = model(text_padded, mel_padded, align_padded,
                         text_lengths, mel_lengths, criterion, stage=1)
        # `.mean()` reduces whatever the DataParallel forward returns (possibly
        # one loss per replica -- confirm); dividing by accum_steps makes the
        # gradients summed over one accumulation window equal their mean.
        sub_loss = fft_loss.mean() / accum_steps
        sub_loss.backward()
        loss += sub_loss.item()
        iteration += 1
        opt_step = iteration // accum_steps

        if iteration % accum_steps == 0:
            # End of an accumulation window: schedule LR, clip, step, reset.
            lr_scheduling(optimizer, opt_step)
            nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
            optimizer.step()
            model.zero_grad()
            writer.add_scalar('Train loss', loss, opt_step)
            loss = 0

        if iteration % validation_period == 0:
            validate(model, criterion, val_loader, iteration, writer)

        if iteration % checkpoint_period == 0:
            save_checkpoint(model,
                            optimizer,
                            hparams.lr,
                            opt_step,
                            filepath=f'{hparams.output_directory}/{hparams.log_directory}/stage1')

        if iteration == max_iteration:
            finished = True
            break

  ids = lengths.new_tensor(torch.arange(0, max_len))


Saving model and optimizer state at iteration 10000 to training_log/aligntts/stage1
Saving model and optimizer state at iteration 20000 to training_log/aligntts/stage1
Saving model and optimizer state at iteration 30000 to training_log/aligntts/stage1
Saving model and optimizer state at iteration 40000 to training_log/aligntts/stage1
