# **Тренировочный ноутбук для Tacotron2(multispeaker) + HiFiTTS** 
Обучение модели [Tacotron2](https://github.com/NVIDIA/tacotron2) (модифицированной под многоголосовую речь) на датасете HiFiTTS ([пайплайн](https://github.com/ViktorKrasnorutskiy/tacotron2_hifitts/blob/main/train/prepare_dataset.ipynb) подготовки датасета [HiFiTTS](http://www.openslr.org/109/)).

In [None]:
# On the free Colab tier a Tesla T4 is preferable (a Tesla K80 is worse).
# To release the assigned GPU use "Runtime / Disconnect and delete runtime".
# List the GPU attached to this runtime.
!nvidia-smi -L

In [None]:
# Mount Google Drive.
from google.colab import drive

# NOTE(review): the mount point is the relative path 'drive', while later cells
# read from '/content/drive/...' — presumably the working directory is /content
# when this cell runs; confirm.
drive.mount('drive')

In [None]:
# Download the Tacotron2 repository.
# TensorFlow 1.x is selected because the hyperparameters are built with
# tf.contrib.training.HParams further below.
%tensorflow_version 1.x
import os

!git clone -q https://github.com/NVIDIA/tacotron2
# All later relative paths (filelists/, outdir, ...) are resolved inside the repo.
os.chdir('tacotron2')
!git submodule init
!git submodule update
!pip install -q unidecode tensorboardX

In [None]:
# Extract the previously uploaded zip archive with the dataset mel-spectrograms.
# The archive must be uploaded to Google Drive beforehand.
import zipfile
import os

zipped_dataset = '/content/drive/MyDrive/data/mels.zip'
# The context manager closes the archive handle even if extraction fails.
# Files are extracted into the current working directory.
with zipfile.ZipFile(zipped_dataset, 'r') as z:
    z.extractall()

In [None]:
#@title Модифицированная модель под Multispeaker

import os
import time
import math
from numpy import finfo
import numpy as np
import tensorflow as tf
import torch
from model import LocationLayer, Attention, Prenet, Postnet, Encoder
from torch.utils.data import DataLoader
from text import text_to_sequence, symbols
import layers
import torch.utils.data
import random
from loss_function import Tacotron2Loss
from logger import Tacotron2Logger
from hparams import create_hparams
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from layers import ConvNorm, LinearNorm
import matplotlib.pylab as plt


# Device the model is moved to: the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def to_gpu(x):
    """Return a contiguous copy of ``x``, moved to the GPU when CUDA is available.

    Args:
        x: Tensor to prepare for the model.

    Returns:
        The contiguous tensor (on the GPU if one is available), wrapped in
        ``torch.autograd.Variable``.
    """
    tensor = x.contiguous()
    on_gpu = torch.cuda.is_available()
    # non_blocking lets the host-to-device copy overlap with other work
    tensor = tensor.cuda(non_blocking=True) if on_gpu else tensor
    return torch.autograd.Variable(tensor)


def create_hparams(hparams_string=None, verbose=False):
    """Build the default hyperparameters, optionally overridden from a string.

    Args:
        hparams_string: Optional string of overrides handed to
            ``HParams.parse``; skipped when empty or ``None``.
        verbose: When true, log the final hyperparameter values.

    Returns:
        The populated ``tf.contrib.training.HParams`` object.
    """
    # Experiment parameters
    experiment = dict(
        epochs=500,
        iters_per_checkpoint=1000,
        seed=1234,
        dynamic_loss_scaling=True,
        fp16_run=False,
        distributed_run=False,
        dist_backend="nccl",
        dist_url="tcp://localhost:54321",
        cudnn_enabled=True,
        cudnn_benchmark=False,
        ignore_layers=['embedding.weight'],
    )

    # Data parameters
    data = dict(
        load_mel_from_disk=True,
        training_files='filelists/train.txt',
        validation_files='filelists/test.txt',
        text_cleaners=['english_cleaners'],
    )

    # Audio parameters
    audio = dict(
        max_wav_value=32768.0,
        sampling_rate=22050,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=8000.0,
    )

    # Model parameters
    architecture = dict(
        n_symbols=len(symbols),
        symbols_embedding_dim=512,

        # Speaker embedding (multispeaker extension)
        n_speakers=10,
        speakers_embedding_dim=2,

        # Encoder
        encoder_kernel_size=5,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,

        # Decoder
        n_frames_per_step=1,  # currently only 1 is supported
        decoder_rnn_dim=1024,
        prenet_dim=256,
        max_decoder_steps=1000,
        gate_threshold=0.5,
        p_attention_dropout=0.1,
        p_decoder_dropout=0.1,

        # Attention
        attention_rnn_dim=1024,
        attention_dim=128,

        # Location layer
        attention_location_n_filters=32,
        attention_location_kernel_size=31,

        # Mel post-processing network
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
    )

    # Optimization hyperparameters
    optimization = dict(
        use_saved_learning_rate=False,
        learning_rate=(1e-3) / (2 ** 0.5),  # 1e-3 scaled down by sqrt(2)
        weight_decay=1e-6,
        grad_clip_thresh=1.0,
        batch_size=64,
        mask_padding=True,  # set model's padded outputs to padded values
    )

    hparams = tf.contrib.training.HParams(
        **experiment, **data, **audio, **architecture, **optimization)

    if hparams_string:
        tf.logging.info('Parsing command line hparams: %s', hparams_string)
        hparams.parse(hparams_string)
    if verbose:
        tf.logging.info('Final parsed hparams: %s', hparams.values())
    return hparams


def get_mask_from_lengths(lengths):
    """Build a boolean mask that is True at valid (non-padded) positions.

    Args:
        lengths: 1-D integer tensor of sequence lengths.

    Returns:
        Boolean tensor of shape ``(len(lengths), max(lengths))`` on the same
        device as ``lengths``; entry ``[i, t]`` is True when ``t < lengths[i]``.
    """
    max_len = torch.max(lengths).item()
    # Build the position indices on the device of `lengths` instead of a
    # globally chosen device, so the comparison below never mixes devices.
    ids = torch.arange(max_len, dtype=torch.long, device=lengths.device)
    return (ids < lengths.unsqueeze(1)).bool()


class Decoder(nn.Module):
    """Autoregressive attention decoder working on speaker-augmented memory.

    Compared to the single-speaker decoder, every layer that consumes the
    encoder output (the "memory") is sized with ``encoder_out_embedding_dim``,
    i.e. the encoder embedding width plus the speaker embedding width.
    """

    def __init__(self, hparams, encoder_out_embedding_dim):
        """Create the decoder layers.

        Args:
            hparams: Hyperparameter object (see ``create_hparams``).
            encoder_out_embedding_dim: Width of the last dimension of the
                memory passed to the decoder (encoder + speaker embedding).
        """
        super().__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        # Multispeaker: the memory width includes the speaker embedding
        self.speakers_embedding_dim = hparams.speakers_embedding_dim
        self.encoder_embedding_dim = encoder_out_embedding_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout
        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])
        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + encoder_out_embedding_dim,
            hparams.attention_rnn_dim)
        self.attention_layer = Attention(
            hparams.attention_rnn_dim, encoder_out_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)
        # The third positional argument of nn.LSTMCell is `bias`; 1 is truthy,
        # so the bias stays enabled.
        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + encoder_out_embedding_dim,
            hparams.decoder_rnn_dim, 1)
        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + encoder_out_embedding_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)
        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + encoder_out_embedding_dim,
            1, bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """Return the all-zero first decoder input.

        Args:
            memory: Decoder memory; only its batch size, dtype and device
                are used.

        Returns:
            Zero tensor of shape ``(B, n_mel_channels * n_frames_per_step)``.
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """Reset the recurrent and attention state and store the memory.

        Args:
            memory: Decoder memory; dimension 0 is the batch, dimension 1 the
                encoder time axis.
            mask: Mask handed to the attention layer on every step, or
                ``None`` for inference.
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)
        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B,
            self.encoder_embedding_dim
        ).zero_())
        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """Rearrange target mels so that the first axis is the decoder step.

        Args:
            decoder_inputs: Target mel-spectrograms with the batch on axis 0;
                axes 1 and 2 are swapped first, then ``n_frames_per_step``
                consecutive frames are grouped into one step.

        Returns:
            Tensor with the decoder step on axis 0 and the batch on axis 1.
        """
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """Stack per-step outputs into batch-first tensors.

        Args:
            mel_outputs: List with one mel prediction per decoder step.
            gate_outputs: List with one gate prediction per decoder step.
            alignments: List with one attention-weight tensor per step.

        Returns:
            Tuple ``(mel_outputs, gate_outputs, alignments)``; the mel tensor
            ends up as ``(B, n_mel_channels, frames)``.
        """
        alignments = torch.stack(alignments).transpose(0, 1)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # Decouple frames that were predicted together in one step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        mel_outputs = mel_outputs.transpose(1, 2)
        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """Run one decoder step and update the stored state.

        Args:
            decoder_input: Prenet output for the previous frame.

        Returns:
            Tuple ``(decoder_output, gate_prediction, attention_weights)``.
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)
        # The attention layer sees both the previous and the cumulative weights
        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)
        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)
        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)
        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, decoder_inputs, memory_lengths):
        """Decode with teacher forcing (training / validation).

        Args:
            memory: Decoder memory (encoder outputs joined with speaker
                embeddings).
            decoder_inputs: Target mel-spectrograms fed back as inputs.
            memory_lengths: Valid length of each memory sequence; positions
                beyond it are masked in the attention.

        Returns:
            Tuple ``(mel_outputs, gate_outputs, alignments)``.
        """
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)
        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))
        mel_outputs, gate_outputs, alignments = [], [], []
        # One step per target frame; the last shifted input is never consumed
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]
        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments

    def inference(self, memory):
        """Decode autoregressively until the gate fires or the step limit hits.

        Args:
            memory: Decoder memory (encoder outputs joined with speaker
                embeddings). The stop test compares the whole gate tensor to
                the threshold, so a single-item batch is expected.

        Returns:
            Tuple ``(mel_outputs, gate_outputs, alignments)``.
        """
        decoder_input = self.get_go_frame(memory)
        self.initialize_decoder_states(memory, mask=None)
        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)
            # Exactly one entry per step: appending the frame twice would
            # duplicate every frame and halve the effective step limit.
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]
            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break
            # Feed the predicted frame back as the next input
            decoder_input = mel_output
        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments


class Tacotron2(nn.Module):
    """Tacotron2 extended with a learned speaker embedding.

    The speaker embedding is repeated along the text axis and concatenated to
    the encoder outputs before they are handed to the decoder.
    """

    def __init__(self, hparams):
        """Create the sub-modules from the hyperparameter object."""
        super().__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        # Multispeaker: one learned vector per speaker id
        self.speakers_embedding = nn.Embedding(
            hparams.n_speakers,
            hparams.speakers_embedding_dim)
        # The decoder consumes encoder outputs widened by the speaker vector
        encoder_out_embedding_dim = (
            hparams.encoder_embedding_dim + hparams.speakers_embedding_dim)
        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams, encoder_out_embedding_dim)
        self.postnet = Postnet(hparams)

    def parse_batch(self, batch):
        """Move a collated batch to the GPU (when available) and split it.

        Args:
            batch: Tuple ``(text_padded, input_lengths, mel_padded,
                gate_padded, output_lengths, speaker_ids)``.

        Returns:
            Tuple ``(x, y)`` where ``x`` is the model input tuple accepted by
            ``forward`` and ``y`` is ``(mel_padded, gate_padded)``.
        """
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths, speaker_ids = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()
        # Speaker ids are cast to long so they can index the embedding table
        speaker_ids = to_gpu(speaker_ids).long()
        return (
            (text_padded, input_lengths, mel_padded, max_len, output_lengths, speaker_ids),
            (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths=None):
        """Overwrite padded positions of the outputs in place.

        Args:
            outputs: List ``[mel, mel_postnet, gate, alignments]``.
            output_lengths: Valid mel length per batch item; when ``None`` (or
                when ``mask_padding`` is off) nothing is masked.

        Returns:
            The same ``outputs`` list.
        """
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)
            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies
        return outputs

    def forward(self, inputs):
        """Run the teacher-forced forward pass.

        Args:
            inputs: Tuple ``(text_inputs, text_lengths, mels, max_len,
                output_lengths, speaker_ids)`` as produced by ``parse_batch``.

        Returns:
            List ``[mel_outputs, mel_outputs_postnet, gate_outputs,
            alignments]`` with padded positions masked.
        """
        text_inputs, text_lengths, mels, max_len, output_lengths, speaker_ids = inputs
        text_lengths, output_lengths = text_lengths.data, output_lengths.data
        outputs = []
        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
        outputs.append(encoder_outputs)
        # Repeat each speaker vector along the text axis
        speaker_ids = speaker_ids.unsqueeze(1)
        embedded_speakers = self.speakers_embedding(speaker_ids)
        embedded_speakers = embedded_speakers.expand(-1, max_len, -1)
        outputs.append(embedded_speakers)
        merged_outputs = torch.cat(outputs, -1)
        mel_outputs, gate_outputs, alignments = self.decoder(
            merged_outputs, mels, memory_lengths=text_lengths)
        # The postnet predicts a residual added to the decoder output
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths)

    def inference(self, inputs, speaker_id):
        """Synthesize mel-spectrograms for one text sequence and speaker.

        Args:
            inputs: Tensor of symbol ids with the batch on axis 0.
            speaker_id: Integer index of the speaker to synthesize.

        Returns:
            List ``[mel_outputs, mel_outputs_postnet, gate_outputs,
            alignments]``.
        """
        outputs = []
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        outputs.append(encoder_outputs)
        # Create the index on the same device as the text input: a CPU tensor
        # cannot index an embedding table that lives on the GPU.
        speaker_id = torch.tensor(
            [speaker_id], dtype=torch.long, device=inputs.device)
        speaker_id = speaker_id.unsqueeze(1)
        embedded_speaker = self.speakers_embedding(speaker_id)
        embedded_speaker = embedded_speaker.expand(-1, encoder_outputs.shape[1], -1)
        outputs.append(embedded_speaker)
        merged_outputs = torch.cat(outputs, -1)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            merged_outputs)
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
        return outputs


def load_filepaths_and_text(filename, split="|"):
    """Read a filelist and split every line into its fields.

    Args:
        filename: Path of the UTF-8 encoded filelist.
        split: Field separator used inside each line.

    Returns:
        A list with one list of fields per line; each line is stripped of
        surrounding whitespace before it is split.
    """
    rows = []
    with open(filename, encoding='utf-8') as handle:
        for raw_line in handle:
            rows.append(raw_line.strip().split(split))
    return rows


class TextMelLoader(torch.utils.data.Dataset):
    """Dataset yielding ``(text sequence, mel-spectrogram, speaker id)`` triples.

    Every filelist row is read as: mel path (without the ``.npy`` suffix),
    transcript, speaker id. Mel-spectrograms are loaded from ``.npy`` files.
    """

    def __init__(self, audiopaths_and_text, hparams):
        """Load the filelist and shuffle it with the configured seed.

        Args:
            audiopaths_and_text: Path of the filelist to load.
            hparams: Hyperparameter object (see ``create_hparams``).
        """
        entries = load_filepaths_and_text(audiopaths_and_text)
        self.audiopaths_and_text = entries
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.load_mel_from_disk = hparams.load_mel_from_disk
        stft_args = (
            hparams.filter_length, hparams.hop_length, hparams.win_length,
            hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
            hparams.mel_fmax)
        self.stft = layers.TacotronSTFT(*stft_args)
        # Seeding the global RNG makes the shuffle reproducible
        random.seed(hparams.seed)
        random.shuffle(entries)

    def get_mel_text_pair(self, audiopath_and_text):
        """Turn one filelist row into ``(text, mel, speaker_id)``.

        The speaker id is returned exactly as read from the filelist.
        """
        audiopath = audiopath_and_text[0]
        text = audiopath_and_text[1]
        speaker_id = audiopath_and_text[2]
        return (self.get_text(text), self.get_mel(audiopath), speaker_id)

    def get_mel(self, filename):
        """Load the mel-spectrogram stored at ``filename + '.npy'``.

        Raises:
            AssertionError: If the first dimension of the array differs from
                the configured number of mel channels.
        """
        melspec = torch.from_numpy(np.load(filename + '.npy'))
        expected_channels = self.stft.n_mel_channels
        assert melspec.size(0) == expected_channels, (
            'Mel dimension mismatch: given {}, expected {}'.format(
                melspec.size(0), expected_channels))
        return melspec

    def get_text(self, text):
        """Convert a transcript into an ``IntTensor`` of symbol ids."""
        return torch.IntTensor(text_to_sequence(text, self.text_cleaners))

    def __getitem__(self, index):
        row = self.audiopaths_and_text[index]
        return self.get_mel_text_pair(row)

    def __len__(self):
        return len(self.audiopaths_and_text)


class TextMelCollate():
    """Zero-pad a list of ``(text, mel, speaker_id)`` samples into batch tensors."""

    def __init__(self, n_frames_per_step):
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collate one batch, ordered by decreasing text length.

        Args:
            batch: List of ``(text, mel, speaker_id)`` samples.

        Returns:
            Tuple ``(text_padded, input_lengths, mel_padded, gate_padded,
            output_lengths, speaker_ids)``.
        """
        batch_size = len(batch)
        text_lengths = torch.LongTensor([len(sample[0]) for sample in batch])
        input_lengths, ids_sorted_decreasing = torch.sort(
            text_lengths, dim=0, descending=True)
        max_input_len = input_lengths[0]

        # Round the target length up to a multiple of n_frames_per_step
        num_mels = batch[0][1].size(0)
        max_target_len = max([sample[1].size(1) for sample in batch])
        remainder = max_target_len % self.n_frames_per_step
        if remainder != 0:
            max_target_len += self.n_frames_per_step - remainder
            assert max_target_len % self.n_frames_per_step == 0

        text_padded = torch.LongTensor(batch_size, max_input_len)
        text_padded.zero_()
        mel_padded = torch.FloatTensor(batch_size, num_mels, max_target_len)
        mel_padded.zero_()
        gate_padded = torch.FloatTensor(batch_size, max_target_len)
        gate_padded.zero_()
        output_lengths = torch.LongTensor(batch_size)
        speaker_list = []

        for row, source_index in enumerate(ids_sorted_decreasing):
            sample = batch[source_index]
            text, mel = sample[0], sample[1]
            n_frames = mel.size(1)
            text_padded[row, :text.size(0)] = text
            mel_padded[row, :, :n_frames] = mel
            # The gate target is 1 from the last real frame onwards
            gate_padded[row, n_frames-1:] = 1
            output_lengths[row] = n_frames
            speaker_list.append(int(sample[2]))

        speaker_ids = torch.Tensor(speaker_list)
        return text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths, speaker_ids


def prepare_dataloaders(hparams):
    """Create the training loader, the validation dataset and the collate function.

    Args:
        hparams: Hyperparameter object providing the filelist paths,
            ``n_frames_per_step`` and ``batch_size``.

    Returns:
        Tuple ``(train_loader, valset, collate_fn)``; the validation set is
        returned as a dataset, not wrapped in a loader.
    """
    collate_fn = TextMelCollate(hparams.n_frames_per_step)
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams)

    loader_options = dict(
        num_workers=1,
        shuffle=True,
        sampler=None,
        batch_size=hparams.batch_size,
        pin_memory=False,
        drop_last=True,  # incomplete final batches are discarded
        collate_fn=collate_fn,
    )
    train_loader = DataLoader(trainset, **loader_options)
    return train_loader, valset, collate_fn


def prepare_directories_and_logger(output_directory, log_directory, rank):
    """Create the output directory and the logger on the rank-0 process.

    Args:
        output_directory: Directory created (mode 0o775) when missing.
        log_directory: Sub-directory of ``output_directory`` for the logger.
        rank: Process rank; only rank 0 touches the disk and gets a logger.

    Returns:
        A ``Tacotron2Logger`` for rank 0, otherwise ``None``.
    """
    if rank != 0:
        return None

    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    log_path = os.path.join(output_directory, log_directory)
    return Tacotron2Logger(log_path)


def load_model(hparams):
    """Instantiate the multispeaker Tacotron2 model on the selected device."""
    return Tacotron2(hparams).to(device)


def warm_start_model(checkpoint_path, model, ignore_layers):
    """Initialize ``model`` from a checkpoint, skipping the ignored layers.

    Args:
        checkpoint_path: Path of a checkpoint holding a ``'state_dict'`` entry.
        model: Model whose weights are overwritten.
        ignore_layers: State-dict keys that keep the model's current values.

    Returns:
        The same ``model`` instance with the weights loaded.
    """
    assert os.path.isfile(checkpoint_path)
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    pretrained = torch.load(checkpoint_path, map_location='cpu')['state_dict']
    if len(ignore_layers) > 0:
        # Start from the model's own weights and overlay the checkpoint ones
        merged = model.state_dict()
        for name, tensor in pretrained.items():
            if name in ignore_layers:
                continue
            merged[name] = tensor
        pretrained = merged
    model.load_state_dict(pretrained)
    return model


def load_checkpoint(checkpoint_path, model, optimizer):
    """Restore the model, optimizer and training progress from a checkpoint.

    Args:
        checkpoint_path: Path of a checkpoint with the entries
            ``'state_dict'``, ``'optimizer'``, ``'learning_rate'`` and
            ``'iteration'``.
        model: Model receiving the saved weights.
        optimizer: Optimizer receiving the saved state.

    Returns:
        Tuple ``(model, optimizer, learning_rate, iteration)``.
    """
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    saved_lr, saved_iteration = state['learning_rate'], state['iteration']
    print("Loaded checkpoint '{}' from iteration {}" .format(
        checkpoint_path, saved_iteration))
    return model, optimizer, saved_lr, saved_iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, filepath, drive_path,
                    iters_per_checkpoint):
    """Save the training state into a rotating checkpoint slot on Google Drive.

    The slot index is ``(iteration // iters_per_checkpoint) % 10``, so at most
    ten files (``checkpoint_0`` .. ``checkpoint_9``) exist under
    ``drive_path`` and older ones are overwritten.

    Saving is best effort: a failure is reported on stdout and does not
    interrupt the caller.

    Args:
        model: Model whose ``state_dict`` is stored.
        optimizer: Optimizer whose ``state_dict`` is stored.
        learning_rate: Learning rate stored next to the weights.
        iteration: Current training iteration.
        filepath: Local checkpoint path; kept for interface compatibility and
            not used, because nothing is written locally.
        drive_path: Directory that receives the checkpoint file.
        iters_per_checkpoint: Number of iterations between two checkpoints.
    """
    try:
        itr = (iteration // iters_per_checkpoint) % 10
        target_path = os.path.join(drive_path, f'checkpoint_{itr}')
        print(f"Saving model and optimizer state at iteration {iteration} ({itr}) to {target_path}")
        torch.save({'iteration': iteration,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'learning_rate': learning_rate}, target_path)
    except Exception as error:
        # Deliberately not re-raised: a failed Drive write must not abort a
        # long training run, but the cause is reported instead of hidden.
        print(f'EXCEPT WHEN SAVING CHECKPOINT AT GOOGLE DRIVE: {error!r}')


''' PLOTTING '''
def plot_alignment(alignment, info=None):
    """Draw an attention alignment matrix as a heat map.

    Args:
        alignment: 2-D array shown with ``imshow``; the vertical axis is
            labelled as encoder timesteps and the horizontal axis as decoder
            timesteps.
        info: Optional text appended below the x-axis label.
    """
    alignment_graph_height = 600
    alignment_graph_width = 1000
    # `%matplotlib inline` is IPython-only syntax. Calling the same magic
    # through the API keeps the notebook behaviour and leaves the function
    # valid Python when no IPython shell is present.
    try:
        get_ipython().run_line_magic('matplotlib', 'inline')
    except NameError:
        pass
    fig, ax = plt.subplots(figsize=(int(alignment_graph_width/100), int(alignment_graph_height/100)))
    im = ax.imshow(alignment, cmap='inferno', aspect='auto', origin='lower',
                   interpolation='none')
    ax.autoscale(enable=True, axis="y", tight=True)
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()
    fig.canvas.draw()
    plt.show()
''' PLOTTING '''


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank):
    """Compute the mean validation loss and, on rank 0, log and plot it.

    The model is switched to eval mode for the pass and back to train mode
    afterwards. On rank 0 the loss is printed and logged, and the alignment of
    one random item of the last validation batch is plotted.

    Args:
        model: Model providing ``parse_batch`` and a forward pass.
        criterion: Loss callable taking ``(y_pred, y)``.
        valset: Validation dataset.
        iteration: Current training iteration (used for reporting).
        batch_size: Validation batch size.
        n_gpus: Unused; kept for interface compatibility.
        collate_fn: Collate function for the validation loader.
        logger: Logger with a ``log_validation`` method (used on rank 0 only).
        distributed_run: Unused; kept for interface compatibility.
        rank: Process rank; only rank 0 reports.
    """
    model.eval()
    with torch.no_grad():
        val_loader = DataLoader(valset, sampler=None, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)
        val_loss = 0.0
        num_batches = 0
        for batch in val_loader:
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            val_loss += criterion(y_pred, y).item()
            num_batches += 1
    model.train()

    if num_batches == 0:
        # Nothing to average, log or plot; avoid failing on undefined results
        print("Validation skipped at iteration {}: validation set is empty".format(
            iteration))
        return

    val_loss = val_loss / num_batches
    if rank == 0:
        print("Validation loss {}: {:9f}  ".format(iteration, val_loss))
        logger.log_validation(val_loss, model, y, y_pred, iteration)
        # y_pred is [mel, mel_postnet, gate, alignments]; plot_alignment
        # selects the inline backend itself before drawing.
        alignments = y_pred[3]
        idx = random.randint(0, alignments.size(0) - 1)
        plot_alignment(alignments[idx].data.cpu().numpy().T)


def train(output_directory, log_directory, checkpoint_path, warm_start, 
          hparams, drive_path,
          n_gpus=1, rank=0, group_name='group_name'):
    """Train the multispeaker Tacotron2 model.

    Args:
        output_directory: Local directory for logs; created when missing.
        log_directory: Sub-directory of ``output_directory`` used by the logger.
        checkpoint_path: Checkpoint to start from, or ``None`` to start from
            scratch.
        warm_start: When true, only model weights are loaded from the
            checkpoint (minus ``hparams.ignore_layers``); otherwise the
            optimizer state and the iteration counter are restored as well.
        hparams: Hyperparameter object (see ``create_hparams``).
        drive_path: Directory that receives the periodic checkpoints.
        n_gpus: Forwarded to ``validate``.
        rank: Process rank; only rank 0 logs and saves checkpoints.
        group_name: Not used inside this function.
    """
    print('START TRAIN')
    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)
    model = load_model(hparams)
    learning_rate = hparams.learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=hparams.weight_decay)
    criterion = Tacotron2Loss()
    logger = prepare_directories_and_logger(
        output_directory, log_directory, rank)
    train_loader, valset, collate_fn = prepare_dataloaders(hparams)
    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            # Weights only: the iteration counter and optimizer start fresh
            model = warm_start_model(
                checkpoint_path, model, hparams.ignore_layers)
        else:
            model, optimizer, _learning_rate, iteration = load_checkpoint(
                checkpoint_path, model, optimizer)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration += 1  # continue with the iteration after the saved one
            # Resume the epoch counter from the number of completed passes
            epoch_offset = max(0, int(iteration / len(train_loader)))
    model.train()
    # Never set to True in this function, so the checks below always pass
    is_overflow = False
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            # Re-apply the learning rate on every step
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            reduced_loss = loss.item()
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)
            optimizer.step()
            # Gradients are cleared after the step, ready for the next batch
            model.zero_grad()
            if not is_overflow and rank == 0:
                duration = time.perf_counter() - start
                print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                    iteration, reduced_loss, grad_norm, duration))
                logger.log_training(
                    reduced_loss, grad_norm, learning_rate, duration, iteration)
            # Validate and checkpoint every iters_per_checkpoint iterations
            if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
                validate(model, criterion, valset, iteration,
                         hparams.batch_size, n_gpus, collate_fn, logger,
                         hparams.distributed_run, rank)
                if rank == 0:
                    # This local path is passed as `filepath`; save_checkpoint
                    # writes under drive_path instead.
                    checkpoint_path = os.path.join(
                        output_directory, "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path, drive_path, hparams.iters_per_checkpoint)
            # Progress within the current epoch
            if i % 100 == 0:
                print(i, 'of', len(train_loader))
            iteration += 1

In [None]:
# Build the default hyperparameters and apply the cuDNN switches.
hparams = create_hparams()
torch.backends.cudnn.enabled = hparams.cudnn_enabled
torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
print("FP16 Run:", hparams.fp16_run)
print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
print("Distributed Run:", hparams.distributed_run)
print("cuDNN Enabled:", hparams.cudnn_enabled)
print("cuDNN Benchmark:", hparams.cudnn_benchmark)

# Local directories for saving logs and checkpoints
outdir = 'outdir'
logdir = 'logdir'

# To continue training an already pre-trained model, upload its checkpoint to Google Drive
# ('/content/drive/MyDrive/data/checkpoint_8800' - here the data folder that also holds the mel-spectrogram zip archive)
checkpoint_path = '/content/drive/MyDrive/data/checkpoint_8800' # or None to train from scratch
warm_start = True # False restores the optimizer state and iteration counter as well

# No GPU out-of-memory errors occurred while training on a Tesla T4
hparams.batch_size = 64

# How often the model is validated and saved, in iterations (each save is ~300 MB)
hparams.iters_per_checkpoint = 20

# The HiFiTTS dataset consists of 10 speakers
hparams.n_speakers = 10

# It is better to tune this parameter yourself
hparams.speakers_embedding_dim = 2

# The /logs directory must be created in Google Drive
drive_path = '/content/drive/MyDrive/logs'
# These files (train.txt and test.txt) must be uploaded to the /tacotron2/filelists directory
hparams.training_files = 'filelists/train.txt'
hparams.validation_files = 'filelists/test.txt'

In [None]:
# Start training
train(outdir, logdir, checkpoint_path, warm_start, hparams, drive_path)

Чтобы избежать отключения Colab из-за бездействия, необходимо вставить и запустить приведённый ниже код в консоли браузера (F12 открывает инструменты разработчика в большинстве браузеров):

```
function ClickConnect(){
    console.log("Clicked on connect button"); 
    document.querySelector("colab-connect-button").click()
}
setInterval(ClickConnect,60000)
```

