In [1]:
import os, argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.model import Model
from modules.loss import MDNLoss
import hparams
from text import *
from utils.utils import *
from utils.writer import get_writer
from torch.utils.tensorboard import SummaryWriter
import math
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm


os.environ["CUDA_VISIBLE_DEVICES"]='0, 1'
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)

train_loader, val_loader, collate_fn = prepare_dataloaders(hparams, stage=1)

# checkpoint_path = f"training_log/aligntts/checkpoint_40000"
checkpoint_path = f"training_log/aligntts/checkpoint_100000"
state_dict = {}
for k, v in torch.load(checkpoint_path)['state_dict'].items():
    state_dict[k[7:]]=v

model = Model(hparams).cuda()
model.load_state_dict(state_dict)
model = nn.DataParallel(model).cuda()

criterion = MDNLoss()
writer = get_writer(hparams.output_directory, f'{hparams.log_directory}/stage1')
optimizer = torch.optim.Adam(model.parameters(),
                             lr=hparams.lr,
                             betas=(0.9, 0.98),
                             eps=1e-09)


def validate(model, criterion, val_loader, iteration, writer):
    model.eval()
    with torch.no_grad():
        n_data, val_loss = 0, 0
        for i, batch in enumerate(val_loader):
            n_data += len(batch[0])
            text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
                reorder_batch(x, hparams.n_gpus).cuda() for x in batch
            ]
            mel_padded = (mel_padded - torch.min(mel_padded))\
                         / torch.max((mel_padded - torch.min(mel_padded)))

            encoder_input = model.module.Prenet(text_padded)
            hidden_states, _ = model.module.FFT_lower(encoder_input, text_lengths)
            mel_out = model.module.get_melspec(hidden_states, align_padded, mel_lengths)
            fft_loss = nn.L1Loss()(mel_out, mel_padded)
            val_loss += fft_loss.item() * len(batch[0])

        val_loss /= n_data

    print(f"Validation Loss: {val_loss}")
    plt.figure(figsize=(15,4))
    plt.imshow(mel_padded[0].detach().cpu(), aspect='auto', origin='lower')
    plt.show()

    plt.figure(figsize=(15,4))
    plt.imshow(mel_out[0].detach().cpu(), aspect='auto', origin='lower')
    plt.show()
    
    '''
    writer.add_losses(fft_loss.item(), iteration//hparams.accumulation, 'Validation')
    
    writer.add_specs(mel_padded.detach().cpu(),
                     mel_out.detach().cpu(),
                     mel_lengths.detach().cpu(),
                     iteration//hparams.accumulation, 'Validation')
    
    writer.add_alignments(probable_path.detach().cpu(),
                          text_lengths.detach().cpu(),
                          mel_lengths.detach().cpu(),
                          iteration//hparams.accumulation, 'Validation')
    '''
            
    model.train()

In [2]:
iteration, loss = 0, 0

iteration, _ = load_checkpoint(model, optimizer, None, 
                               f'{hparams.output_directory}/{hparams.log_directory}/stage1')

model.train()

Loading model and optimizer state at training_log/aligntts/stage1/checkpoint_90000


DataParallel(
  (module): Model(
    (Prenet): Prenet(
      (Embedding): Embedding(119, 256)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (FFT_lower): FFT(
      (FFT_layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=1024, bias=True)
          (linear2): Linear(in_features=1024, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=1024, bias=True)
          (linear2): Linear

In [None]:
print("Training Start!!!")



while True:
    for i, batch in tqdm(enumerate(train_loader)):
        text_padded, mel_padded, align_padded, text_lengths, mel_lengths = [
            reorder_batch(x, hparams.n_gpus).cuda() for x in batch
        ]
        mel_padded = (mel_padded - torch.min(mel_padded))\
                     / torch.max((mel_padded - torch.min(mel_padded)))

        fft_loss = model(text_padded,
                         mel_padded,
                         align_padded,
                         text_lengths,
                         mel_lengths,
                         criterion,
                         stage=1)
        sub_loss = fft_loss.mean()/hparams.accumulation
        sub_loss.backward()
        loss = loss+sub_loss.item()
        iteration += 1
        
        if iteration%hparams.accumulation == 0:
            lr_scheduling(optimizer, iteration//hparams.accumulation)
            nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
            optimizer.step()
            model.zero_grad()
            writer.add_losses(loss, iteration//hparams.accumulation, 'Train')
            loss=0


        if iteration%(hparams.iters_per_validation*hparams.accumulation)==0:
            validate(model, criterion, val_loader, iteration, writer)

        if iteration%(hparams.iters_per_checkpoint*hparams.accumulation)==0:
            save_checkpoint(model,
                            optimizer,
                            hparams.lr,
                            iteration//hparams.accumulation,
                            filepath=f'{hparams.output_directory}/{hparams.log_directory}/stage1')

        if iteration==(hparams.train_steps[1]*hparams.accumulation):
            break
            
    if iteration==(hparams.train_steps[1]*hparams.accumulation):
            break

Training Start!!!


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i, batch in tqdm(enumerate(train_loader)):


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

  ids = lengths.new_tensor(torch.arange(0, max_len))
