In [1]:
from typing import Dict, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F

from espnet2.tts.gst.style_encoder import StyleEncoder
from espnet2.tts.abs_tts import AbsTTS
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract

import argparse
from pathlib import Path
import os
from espnet2.fileio.sound_scp import SoundScpReader
import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def get_parser():
    """Build the command-line parser for style-embedding extraction.

    Returns:
        argparse.ArgumentParser: Parser with two positional arguments, both
        converted to ``pathlib.Path``:

        * ``in_folder``  -- the input Kaldi-style data directory;
        * ``out_folder`` -- where the style embeddings are written.
    """
    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (name, help) pairs, registered in this order so positional order is kept.
    positional_specs = (
        ("in_folder", "Path to the input kaldi data directory."),
        ("out_folder", "Output folder to save the style embedding."),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, type=Path, help=arg_help)
    return cli

In [None]:
from typing import Any


class GST(AbsTTS):
    """Standalone Global Style Token (GST) style encoder.

    Wraps ESPnet's ``StyleEncoder`` (the same module FastSpeech2 uses for GST)
    together with an optional feature extractor, so style embeddings can be
    computed directly from speech. Call the instance (``model(speech,
    speech_lengths)``), which dispatches to ``forward``.

    Reference code:
        espnet2/tts/fastspeech2/fastspeech2.py (GST construction, ~L321)
    """

    def __init__(
        self,
        odim: int,
        feats_extract: Optional[AbsFeatsExtract],
        adim: int = 384,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
    ):
        """Build the GST model.

        Args:
            odim: Feature dimension fed to the style encoder (passed to
                ``StyleEncoder`` as ``idim``), e.g. the number of mel bins.
            feats_extract: Module turning raw speech into features, or
                ``None`` when the inputs are already features.
            adim: Passed to ``StyleEncoder`` as ``gst_token_dim``.
            gst_tokens: Passed to ``StyleEncoder`` as ``gst_tokens``.
            gst_heads: Passed to ``StyleEncoder`` as ``gst_heads``.
            gst_conv_layers: Passed to ``StyleEncoder`` as ``conv_layers``.
            gst_conv_chans_list: Passed as ``conv_chans_list``.
            gst_conv_kernel_size: Passed as ``conv_kernel_size``.
            gst_conv_stride: Passed as ``conv_stride``.
            gst_gru_layers: Passed as ``gru_layers``.
            gst_gru_units: Passed as ``gru_units``.
        """
        # AbsTTS is a torch.nn.Module: its __init__ must run before any
        # submodule is assigned, otherwise torch raises AttributeError.
        super().__init__()
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )
        self.feats_extract = feats_extract

    def extract_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Extract features.

        Args:
            speech (Tensor): Batch of raw speech when ``feats_extract`` is set
                (shape depends on the extractor, typically (B, T_wav)),
                otherwise precomputed features (B, T_feats, odim).
            speech_lengths (Tensor): Valid length of each item in ``speech`` (B,).

        Returns:
            Dict[str, Tensor]: ``feats`` and ``feats_lengths``.

        Reference code:
            espnet2/tts/espnet_model.py (extract_feats, ~L153)
        """
        if self.feats_extract is not None:
            feats, feats_lengths = self.feats_extract(speech, speech_lengths)
        else:
            # Use precalculated feats (feats_type != raw case)
            feats, feats_lengths = speech, speech_lengths
        feats_dict = dict(feats=feats, feats_lengths=feats_lengths)

        return feats_dict

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute GST style embeddings.

        Args:
            speech (Tensor): Same as in ``extract_feats``.
            speech_lengths (Tensor): Same as in ``extract_feats``.

        Returns:
            Tensor: Output of ``StyleEncoder`` -- presumably (B, adim);
            TODO confirm against the installed ESPnet version.
        """
        feats = self.extract_feats(speech, speech_lengths)["feats"]
        # StyleEncoder only takes the padded feature batch; the lengths are
        # consumed by the feature extractor above.
        return self.gst(feats)

    def inference(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Return style embeddings in a dict, satisfying ``AbsTTS.inference``.

        Args:
            speech (Tensor): Same as in ``forward``.
            speech_lengths (Tensor): Same as in ``forward``.
            **kwargs: Ignored; accepted for AbsTTS API compatibility.

        Returns:
            Dict[str, Tensor]: ``{"style_embs": forward(speech, speech_lengths)}``.
        """
        return dict(style_embs=self.forward(speech, speech_lengths))

In [None]:
def main(argv):
    """Extract GST style embeddings and average them per speaker.

    Reads ``spk2utt`` and ``wav.scp`` from ``in_folder``, turns every
    utterance into log-mel features, runs the GST style encoder on them and
    writes the per-speaker mean embedding to ``out_folder``:

    * ``spk_style_embeds.npy``  -- stacked mean embeddings, one row per speaker;
    * ``spk_style_embeds.list`` -- speaker IDs, one per line, in row order.

    Args:
        argv: Command-line arguments (without the program name), parsed by
            ``get_parser``.

    Returns:
        dict: Mapping from speaker ID to its mean style embedding (np.ndarray).

    Raises:
        ValueError: If an utterance is not mono, or if utterances have
            different sample rates.
    """
    # Only this cell needs a concrete feature extractor.
    from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank

    parser = get_parser()
    args = parser.parse_args(argv)

    # Mel bins produced by the extractor; must equal the GST input dim (odim).
    n_mels = 40

    # spk2utt lines: "<speaker> <utt1> <utt2> ...".
    spk2utt = dict()
    with open(os.path.join(args.in_folder, "spk2utt"), "r") as reader:
        for line in reader:
            details = line.split()
            if not details:  # tolerate blank lines
                continue
            spk2utt[details[0]] = details[1:]

    wav_scp = SoundScpReader(os.path.join(args.in_folder, "wav.scp"), np.float32)
    os.makedirs(args.out_folder, exist_ok=True)

    # TODO(review): no checkpoint is loaded, so the encoder is randomly
    # initialised and its embeddings are not meaningful until trained GST
    # weights are restored here.
    gst_encoder = None
    fs = None  # sample rate the feature extractor was built for
    spk_embeds = {}

    with torch.no_grad():
        for speaker, utts in tqdm(spk2utt.items(), desc="speakers"):
            style_embeds = []
            for utt in utts:
                in_sr, wav = wav_scp[utt]
                if wav.ndim != 1:
                    raise ValueError(
                        f"Expected mono audio for {utt}, got shape {wav.shape}"
                    )
                if gst_encoder is None:
                    # The extractor needs the sample rate, known only after
                    # the first utterance is read.
                    fs = in_sr
                    gst_encoder = GST(
                        odim=n_mels,
                        feats_extract=LogMelFbank(fs=fs, n_mels=n_mels),
                    )
                    # Inference mode for any normalisation/dropout layers.
                    gst_encoder.eval()
                elif in_sr != fs:
                    raise ValueError(
                        f"Sample rate mismatch for {utt}: expected {fs} Hz, got {in_sr} Hz"
                    )

                speech = torch.from_numpy(wav).unsqueeze(0)  # (1, num_samples)
                speech_lengths = torch.tensor([speech.size(1)], dtype=torch.long)
                feats = gst_encoder.extract_feats(speech, speech_lengths)["feats"]
                embed = gst_encoder.gst(feats)
                style_embeds.append(embed.squeeze(0).cpu().numpy())

            if not style_embeds:  # speaker listed without utterances
                continue
            # Speaker-level embedding: mean over the speaker's utterances.
            spk_embeds[speaker] = np.mean(np.stack(style_embeds, 0), 0)

    speakers = list(spk_embeds)
    if speakers:
        np.save(
            os.path.join(args.out_folder, "spk_style_embeds.npy"),
            np.stack([spk_embeds[spk] for spk in speakers], 0),
        )
        with open(os.path.join(args.out_folder, "spk_style_embeds.list"), "w") as writer:
            writer.writelines(f"{spk}\n" for spk in speakers)
    return spk_embeds