# Retrain Multi-Tacotron2 trained by jvs-dataset
### 考慮点 
##### speaker_idを入力に含めて学習を行う。
----

#### Import

In [27]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
from functools import partial
from logging import Logger
from pathlib import Path

import torch
from matplotlib import pyplot as plt
from omegaconf import DictConfig
from torch import nn
from tqdm import tqdm
import numpy as np
import hydra
from hydra.utils import to_absolute_path
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
import importlib

In [29]:
# from ttslearn.contrib.multispk_util import collate_fn_ms_tacotron, setup
# from ttslearn.util import make_non_pad_mask
# from ttslearn.train_util import (get_epochs_with_optional_tqdm,plot_2d_feats,plot_attention,save_checkpoint,)

from utils.multispk_util import collate_fn_ms_tacotron, setup
from utils import multispk_util
from utils.util import make_non_pad_mask, make_pad_mask, load_utt_list
from utils.train_util import (
    get_epochs_with_optional_tqdm,
    plot_2d_feats,
    plot_attention,
    save_checkpoint,
)
from ttslearn.tacotron.frontend.openjtalk import sequence_to_text

from utils.early_stopping import EarlyStopping

logger: Logger = None

#### spk_idの指定＋保存
----

In [30]:
np.save(
    "data/fine-spk.npy",
    np.array([52], dtype=np.int64),
    allow_pickle=False,
)

#### Config読み込みの関数を定義
----

モデルや学習の際の調整パラメータをhydraを使用してconfigディレクトリから取得する

In [31]:
def load_config():
    hydra_instance = GlobalHydra.instance()
    if not hydra_instance.is_initialized():
        hydra_instance.clear()
        initialize(config_path="conf/train_tacotron")
    
    cfg = compose(config_name="config")
    return cfg


In [32]:
# print(config.data.train.utt_list)

#### データ数の設定

In [33]:
train_data_num = 100
dev_data_num = 100
is_early_stopping = False

### Trainデータ数を指定して.listファイルとして作成

#### 学習手順定義
----

In [34]:
def train_step(
    model,
    optimizer,
    lr_scheduler,
    train,
    criterions,
    in_feats,
    in_lens,
    out_feats,
    out_lens,
    stop_flags,
    spk_ids,
):
    optimizer.zero_grad()

    # Run forwaard
    outs, outs_fine, logits, _ = model(in_feats, in_lens, out_feats, spk_ids)

    # Mask (B x T x 1)
    # 損失を求めるためpadding部分を取り除く
    mask = make_non_pad_mask(out_lens).unsqueeze(-1).to(out_feats.device)
    out_feats = out_feats.masked_select(mask)
    outs = outs.masked_select(mask)
    outs_fine = outs_fine.masked_select(mask)
    stop_flags = stop_flags.masked_select(mask.squeeze(-1))
    logits = logits.masked_select(mask.squeeze(-1))

    # Loss
    decoder_out_loss = criterions["out_loss"](outs, out_feats)
    postnet_out_loss = criterions["out_loss"](outs_fine, out_feats)
    stop_token_loss = criterions["stop_token_loss"](logits, stop_flags)
    loss = decoder_out_loss + postnet_out_loss + stop_token_loss

    loss_values = {
        "DecoderOutLoss": decoder_out_loss.item(),
        "PostnetOutLoss": postnet_out_loss.item(),
        "StopTokenLoss": stop_token_loss.item(),
        "Loss": loss.item(),
    }

    # Update
    if train:
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        if not torch.isfinite(grad_norm):
            logger.info("grad norm is NaN. Skip updating")
        else:
            optimizer.step()
        lr_scheduler.step()

    return loss_values

#### モデル評価関数の定義
----

In [35]:
@torch.no_grad()
def eval_model(
    step, model, writer, in_feats, in_lens, out_feats, out_lens, spk_ids, is_inference
):
    # 最大3つまで
    N = min(len(in_feats), 3)

    if is_inference:
        outs, outs_fine, att_ws, out_lens = [], [], [], []
        for idx in range(N):
            out, out_fine, _, att_w = model.inference(
                in_feats[idx][: in_lens[idx]], spk_ids[idx]
            )
            outs.append(out)
            outs_fine.append(out_fine)
            att_ws.append(att_w)
            out_lens.append(len(out))
    else:
        outs, outs_fine, _, att_ws = model(in_feats, in_lens, out_feats, spk_ids)

    for idx in range(N):
        text = "".join(
            sequence_to_text(in_feats[idx][: in_lens[idx]].cpu().data.numpy())
        )
        if is_inference:
            group = f"utt{idx+1}_inference"
        else:
            group = f"utt{idx+1}_teacher_forcing"

        out = outs[idx][: out_lens[idx]]
        out_fine = outs_fine[idx][: out_lens[idx]]
        rf = model.decoder.reduction_factor
        att_w = att_ws[idx][: out_lens[idx] // rf, : in_lens[idx]]
        fig = plot_attention(att_w)
        writer.add_figure(f"{group}/attention", fig, step)
        plt.close()
        fig = plot_2d_feats(out, text)
        writer.add_figure(f"{group}/out_before_postnet", fig, step)
        plt.close()
        fig = plot_2d_feats(out_fine, text)
        writer.add_figure(f"{group}/out_after_postnet", fig, step)
        plt.close()
        if not is_inference:
            out_gt = out_feats[idx][: out_lens[idx]]
            fig = plot_2d_feats(out_gt, text)
            writer.add_figure(f"{group}/out_ground_truth", fig, step)
            plt.close()


#### 学習ループ処理定義
----


    nepochs loop ----->
    ・loss_params (intervalごとにparamsを保存)

        train & dev loop --->
        ・ ave_loss (Epoch ごとのロスを記録) <- from running_loss
        ・ best_loss_params (最小のlossの更新ごとに保存) if train
    
            batch loop ---------->
            ・sort
            ・train_step (損失を元にparamの更新)
            ・eval_model (intervalごとに) if eval and first data
        
    ・latest_loss_param (保存)

##### 保存場所
- writer : tensorboard/exp
- save_checkpoint : exp
- logger : shellログ

In [36]:
def _update_running_losses_(running_losses, loss_values):
    for key, val in loss_values.items():
        try:
            running_losses[key] += val
        except KeyError:
            running_losses[key] = val


def train_loop(config, device, model, optimizer, lr_scheduler, data_loaders, writer, is_early_stopping):
    early_stopping = EarlyStopping(patience=5)
    print(" Eearly Stopping Mode is : "+ str(is_early_stopping))
    es_flag = False

    criterions = {
        "out_loss": nn.MSELoss(),
        "stop_token_loss": nn.BCEWithLogitsLoss(),
    }

    out_dir = Path(to_absolute_path(config.train.out_dir))
    best_loss = torch.finfo(torch.float32).max
    train_iter = 1
    nepochs = config.train.nepochs
    print("nepochs : " + str(nepochs))
    print("iter per epochs: " + str(train_data_num / 32))
    print("itar sum : " + str(config.train.max_train_steps))
    # for epoch in get_epochs_with_optional_tqdm(config.tqdm, nepochs):
    for epoch in range(1, nepochs + 1):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            model.train() if train else model.eval()
            running_losses = {}
            for idx, (
                in_feats,
                in_lens,
                out_feats,
                out_lens,
                stop_flags,
                spk_ids,
            ) in tqdm(
                enumerate(data_loaders[phase]), desc=f"{phase} iter"
            ):
                # ミニバッチのソート (短い順)
                in_lens, indices = torch.sort(in_lens, dim=0, descending=True)
                in_feats, out_feats, out_lens = (
                    in_feats[indices].to(device),
                    out_feats[indices].to(device),
                    out_lens[indices].to(device),
                )
                stop_flags = stop_flags[indices].to(device)
                spk_ids = spk_ids[indices].to(device)

                loss_values = train_step(
                    model,
                    optimizer,
                    lr_scheduler,
                    train,
                    criterions,
                    in_feats,
                    in_lens,
                    out_feats,
                    out_lens,
                    stop_flags,
                    spk_ids,
                )

                # memo each loss of tacotron2 & lr per batch iter
                if train:
                    for key, val in loss_values.items():
                        writer.add_scalar(f"{key}ByStep/train", val, train_iter)
                    writer.add_scalar(
                        "LearningRate", lr_scheduler.get_last_lr()[0], train_iter
                    )
                    train_iter += 1
                _update_running_losses_(running_losses, loss_values)

                # 最初の検証用データに対して、中間結果の可視化
                if (
                    not train
                    and idx == 0
                    and epoch % config.train.eval_epoch_interval == 0
                ):
                    for is_inference in [False, True]:
                        eval_model(
                            train_iter,
                            model,
                            writer,
                            in_feats,
                            in_lens,
                            out_feats,
                            out_lens,
                            spk_ids,
                            is_inference,
                        )

            # Epoch ごとのロスを出力
            for key, val in running_losses.items():
                ave_loss = val / len(data_loaders[phase])
                writer.add_scalar(f"{key}/{phase}", ave_loss, epoch)

            ave_loss = running_losses["Loss"] / len(data_loaders[phase])
            if not train:
                # 早期終了をするかのフラッグを立てる。epoch loopでbreakをかける
                if is_early_stopping:
                    es_flag = early_stopping(ave_loss)

                if ave_loss < best_loss:
                    best_loss = ave_loss
                    save_checkpoint(logger, out_dir, model, optimizer, epoch, True)

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(logger, out_dir, model, optimizer, epoch, False)

        if es_flag:
            break

    # save at last epoch
    save_checkpoint(logger, out_dir, model, optimizer, nepochs)
    logger.info(f"The best loss was {best_loss}")

    return model

## 実行処理
----

In [37]:
import random

def choose_train_list(config, num):
    utt_ids = load_utt_list(to_absolute_path(config.data.train.utt_list))
    resampled = random.sample(utt_ids, num)
    with open(f"./data/train_{num}.list", "w") as list:
        for i in resampled:
            list.write(str(i) + "\n")
    config.data.train.utt_list = f"./data/train_{num}.list"

def choose_dev_list(config, num):
    utt_ids = load_utt_list(to_absolute_path(config.data.dev.utt_list))
    resampled = random.sample(utt_ids, num)
    with open(f"./data/dev_{num}.list", "w") as list:
        for i in resampled:
            list.write(str(i) + "\n")
    config.data.dev.utt_list = f"./data/dev_{num}.list"

In [38]:
def my_app():
    global logger
    config = load_config()

    choose_train_list(config, train_data_num)
    choose_dev_list(config, dev_data_num)

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    collate_fn = partial(
        collate_fn_ms_tacotron, reduction_factor=config.model.netG.reduction_factor # configはconf/train_tacotron内のすべてのファイルでのyamlが一つのconfigとして読み込まれる
    )

    model, optimizer, lr_scheduler, data_loaders, writer, logger = multispk_util.setup(
        config, device, collate_fn
    )
    print(config.train.pretrained.checkpoint)
    train_loop(config, device, model, optimizer, lr_scheduler, data_loaders, writer, is_early_stopping)



In [39]:
importlib.reload(multispk_util)
my_app()



conf/multspk_tacotron2_hifipwg_jvs24k/acoustic_model.pth
 Eearly Stopping Mode is : False
nepochs : 250
iter per epochs: 3.125
itar sum : 1000


train iter: 4it [02:06, 31.61s/it]
dev iter: 4it [00:27,  6.79s/it]
train iter: 4it [01:48, 27.09s/it]
dev iter: 4it [00:29,  7.32s/it]
train iter: 4it [01:40, 25.05s/it]
dev iter: 4it [00:25,  6.33s/it]
train iter: 4it [01:36, 24.11s/it]
dev iter: 4it [00:25,  6.31s/it]
train iter: 4it [01:28, 22.18s/it]
dev iter: 4it [00:26,  6.67s/it]
train iter: 4it [02:16, 34.06s/it]
dev iter: 4it [00:26,  6.72s/it]
train iter: 4it [01:34, 23.60s/it]
dev iter: 4it [00:27,  6.82s/it]
train iter: 4it [02:08, 32.13s/it]
dev iter: 4it [10:06, 151.69s/it]
train iter: 4it [01:44, 26.06s/it]
dev iter: 4it [00:25,  6.29s/it]
train iter: 4it [01:32, 23.01s/it]
dev iter: 4it [00:25,  6.32s/it]
train iter: 4it [2:32:05, 2281.29s/it]
dev iter: 4it [15:26, 231.57s/it]
train iter: 4it [1:00:49, 912.27s/it]
dev iter: 4it [14:39, 219.83s/it]
train iter: 4it [1:34:18, 1414.67s/it]
dev iter: 4it [57:58, 869.58s/it] 
train iter: 4it [03:57, 59.40s/it] 
dev iter: 4it [00:24,  6.21s/it]
train iter: 4it [05:41, 85.28s/

KeyboardInterrupt: 