In [1]:
import os
import tgt
import glob
import tqdm
import json
import torch
import scipy
import random
import librosa
import sklearn
import speechbrain
import numpy as np
import pyworld as pw
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from text import _clean_text
from pathlib import Path
from IPython.display import Audio
from matplotlib.lines import Line2D
from torch.utils.tensorboard import SummaryWriter
from sklearn.preprocessing import StandardScaler
from speechbrain.lobes.models.FastSpeech2 import mel_spectogram

In [2]:
# ---------------------------------------------------------------------------
# 1. Paths
# ---------------------------------------------------------------------------
DATA_PATH = '/workspace/data/EmoV-DB'
CORPUS_PATH = '/workspace/montreal_forced_aligner/corpus'
TEXTGRID_PATH = '/workspace/montreal_forced_aligner/aligned'
# Holds <speaker>/*.npz feature files plus the fs2_train.txt / fs2_valid.txt lists.
PREPROCESSED_PATH = '/workspace/preprocessed'
EXPERIMENT_PATH = '/workspace/experiments'

# ---------------------------------------------------------------------------
# 2. Preprocessing
# ---------------------------------------------------------------------------
NOISE_SYMBOL = ' [noise] '
SPEAKERS = ['bea', 'jenie', 'josh', 'sam']
EMOTIONS = ['neutral', 'amused', 'angry', 'disgusted', 'sleepy']
SIL_PHONES = ['sil', 'spn', 'sp', '']
# Phoneme vocabulary: a phone's id is its position in this list. Id 0 ('@') is
# what the collate function's zero padding produces.
VALID_TOKENS = ['@', *speechbrain.utils.text_to_sequence.valid_symbols, *SIL_PHONES]
PITCH_AVERAGING = False
ENERGY_AVERAGING = False
MATCH_TRANSCRIPT = True

# ---------------------------------------------------------------------------
# 3. Audio (optimized for vocoder)
# ---------------------------------------------------------------------------
SAMPLING_RATE = 16_000
HOP_LENGTH = 256
WIN_LENGTH = 1024
N_FFT = 1024
N_MELS = 80
F_MIN = 0.0
F_MAX = 8_000.0

# ---------------------------------------------------------------------------
# 4. Training
# NOTE(review): the trainer cell hard-codes batch_size=8 and lr=1e-4 rather
# than using BATCH_SIZE / LEARNING_RATE — confirm which values are intended.
# ---------------------------------------------------------------------------
N_EPOCHS = 100
MAX_ITERATIONS = 50_000
BATCH_SIZE = 16
LEARNING_RATE = 1e-6

# ---------------------------------------------------------------------------
# 5. Model
# ---------------------------------------------------------------------------
N_ENCODER_LAYERS = 4
N_HEADS = 2
HIDDEN_DIM = 256
KERNEL_SIZE = 9
DROPOUT = 0.1
ALPHA = 0.1   # mixup
BETA = 1.0    # rank

# ---------------------------------------------------------------------------
# 6. Miscellaneous (plot styling)
# ---------------------------------------------------------------------------
MARKER = ['o', '^', 's', 'd']
COLORS = ['#7C00FE', '#F9E400', '#FFAF00', '#F5004F', '#00B2A9']

In [3]:
# Split each speaker's preprocessed utterances into train / valid lists.
# The split is written to disk once and then reused, so every later run (and
# FastSpeech2Dataset, which reads these files) sees the same partition. The
# shuffle is seeded and the glob results sorted so a fresh split is reproducible.
TRAIN_LIST_PATH = os.path.join(PREPROCESSED_PATH, 'fs2_train.txt')
VALID_LIST_PATH = os.path.join(PREPROCESSED_PATH, 'fs2_valid.txt')
SPLIT_SEED = 42
TRAIN_RATIO = 0.8


def _read_split_list(path):
    """Return the non-blank, stripped lines of a split list file."""
    with open(path, 'r') as f:
        return [line.strip() for line in f if line.strip()]


def _write_split_list(path, items):
    """Write one path per line (newline-terminated) to a split list file."""
    with open(path, 'w') as f:
        f.write('\n'.join(items) + '\n')


if not os.path.exists(TRAIN_LIST_PATH):
    split_rng = random.Random(SPLIT_SEED)
    train_list, valid_list = [], []
    for speaker in SPEAKERS:
        # glob order is filesystem-dependent; sort before the seeded shuffle
        paths = sorted(glob.glob(os.path.join(PREPROCESSED_PATH, speaker, '*.npz')))
        split_rng.shuffle(paths)

        n_train = int(len(paths) * TRAIN_RATIO)
        train_list.extend(paths[:n_train])
        valid_list.extend(paths[n_train:])

    # Save the train / valid datasets to files
    _write_split_list(TRAIN_LIST_PATH, train_list)
    _write_split_list(VALID_LIST_PATH, valid_list)
else:
    print('Skipping')
    # Keep the in-memory lists consistent with the split actually on disk.
    train_list = _read_split_list(TRAIN_LIST_PATH)
    valid_list = _read_split_list(VALID_LIST_PATH) if os.path.exists(VALID_LIST_PATH) else []

Skipping


In [4]:
def phoneme2sequence(phoneme, vocab=None):
    """Map a sequence of phone symbols to integer ids.

    Parameters
    ----------
    phoneme : iterable of str
        Phone symbols, in order.
    vocab : sequence of str, optional
        Token vocabulary; a symbol's id is its first position in it. Defaults to
        the module-level VALID_TOKENS (looked up at call time).

    Returns
    -------
    list of int

    Raises
    ------
    ValueError
        If a symbol is not in the vocabulary; the message names the symbol.
    """
    tokens = VALID_TOKENS if vocab is None else vocab
    seq = []
    for token in phoneme:
        try:
            seq.append(tokens.index(token))
        except ValueError:
            raise ValueError(
                f'unknown phoneme {token!r}: not in the token vocabulary'
            ) from None
    return seq

def sequence2phoneme(sequence, vocab=None):
    """Map a sequence of integer ids back to phone symbols.

    Parameters
    ----------
    sequence : iterable of int
        Token ids (ints or 0-d integer tensors).
    vocab : sequence of str, optional
        Token vocabulary indexed by id. Defaults to the module-level
        VALID_TOKENS (looked up at call time).

    Returns
    -------
    list of str

    Raises
    ------
    IndexError
        If an id is negative (Python indexing would otherwise silently wrap to
        the end of the vocabulary) or past the end of the vocabulary.
    """
    tokens = VALID_TOKENS if vocab is None else vocab
    phoneme = []
    for i in sequence:
        if i < 0:
            raise IndexError(
                f'token id {int(i)} is negative; ids must be in [0, {len(tokens)})'
            )
        phoneme.append(tokens[i])
    return phoneme

In [5]:
class FastSpeech2Dataset(torch.utils.data.Dataset):
    """Dataset over the preprocessed ``.npz`` utterances listed in ``fs2_{mode}.txt``.

    Each list file (under PREPROCESSED_PATH) holds one ``.npz`` path per line.
    Items are dicts with keys ``mel``, ``pitch``, ``energy`` (float tensors),
    ``duration`` and ``phoneme`` (long tensors, one entry per phone),
    ``speaker`` / ``emotion`` (0-d long indices into SPEAKERS / EMOTIONS),
    ``text`` (transcript with the noise tag removed) and ``audio_path``.
    Per the sample output in this notebook, ``mel`` is (n_mels, frames) and
    ``pitch`` / ``energy`` are (frames,).
    """

    def __init__(self, mode='train'):
        """Read the utterance list for ``mode`` ('train' or 'valid')."""
        super(FastSpeech2Dataset, self).__init__()

        list_path = os.path.join(PREPROCESSED_PATH, f'fs2_{mode}.txt')
        with open(list_path, 'r') as f:
            # Skip blank lines: a list written from an empty split is just "\n",
            # which would otherwise become an empty (unloadable) path.
            self.data_paths = [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        data_path = self.data_paths[idx]

        # NpzFile keeps the archive open until closed; the context manager
        # releases the handle once the arrays have been read into memory.
        # NOTE: allow_pickle=True can execute pickled payloads — only load trusted files.
        with np.load(data_path, allow_pickle=True) as data:
            # Load features
            mel = data['mel']
            pitch = data['pitch']
            energy = data['energy']
            duration = data['durations']
            phoneme = data['phones'].tolist()

            # metadata
            speaker = data['speaker'].item()
            emotion = data['emotion'].item()
            transcript = data['transcript'].item()
            audio_path = data['audio_path'].item()

        text = transcript.replace(NOISE_SYMBOL.strip(), '').strip()

        return {
            'mel': torch.FloatTensor(mel),
            'pitch': torch.FloatTensor(pitch),
            'energy': torch.FloatTensor(energy),
            'duration': torch.LongTensor(duration),
            'phoneme': torch.LongTensor(phoneme2sequence(phoneme)),
            'speaker': torch.tensor(SPEAKERS.index(speaker), dtype=torch.long),
            'emotion': torch.tensor(EMOTIONS.index(emotion), dtype=torch.long),
            'text': text,
            'audio_path': audio_path
        }


# Inspect the first preprocessed training example. Indexing directly avoids
# wrapping the whole dataset in a progress bar that is abandoned after one item.
dataset = FastSpeech2Dataset(mode='train')
assert len(dataset) > 0, 'fs2_train.txt lists no utterances'
data = dataset[0]
print('Melspectrogram shape:', data['mel'].shape)
print('Pitch shape:', data['pitch'].shape)
print('Energy shape:', data['energy'].shape)
print('Duration shape:', data['duration'].shape)
print('Phoneme sequence:', data['phoneme'].shape)
# Durations should sum to the number of mel frames.
print('*Total duration:', data['duration'].sum().item())
print('Speaker index:', data['speaker'])
print('Emotion index:', data['emotion'])
print('Text:', data['text'])

  0%|          | 0/5498 [00:00<?, ?it/s]

Melspectrogram shape: torch.Size([80, 305])
Pitch shape: torch.Size([305])
Energy shape: torch.Size([305])
Duration shape: torch.Size([45])
Phoneme sequence: torch.Size([45])
*Total duration: 305
Speaker index: tensor(0)
Emotion index: tensor(4)
Text: the promoter's eyes were heavy, with little puffy bags under them.


In [6]:
class TextMelCollateWithAlignment:
    """Collate FastSpeech2Dataset items into zero-padded batch tensors.

    Items are reordered by phoneme-sequence length, longest first, and every
    per-item output below follows that order.

    Returns a 10-tuple:
        phoneme_padded   LongTensor  (B, max_phonemes)
        speakers         LongTensor  (B,)
        input_lengths    LongTensor  (B,)  phoneme counts, descending
        mel_padded       FloatTensor (B, max_frames, n_mels)
        pitch_padded     FloatTensor (B, max_frames)
        energy_padded    FloatTensor (B, max_frames)
        duration_padded  LongTensor  (B, max_phonemes)
        output_lengths   LongTensor  (B,)  mel frame counts
        labels           list        transcripts
        wavs             list        audio paths
    The items' 'emotion' field is not included in the batch.
    """

    def __call__(self, batch):
        batch_size = len(batch)

        # Longest phoneme sequence first.
        input_lengths, order = torch.sort(
            torch.LongTensor([len(item['phoneme']) for item in batch]),
            dim=0, descending=True)
        order = order.tolist()
        max_input_len = int(input_lengths[0])

        # Dataset mels are (n_mels, frames); n_mels is taken from the first item.
        num_mels = batch[0]['mel'].size(0)
        max_target_len = max(item['mel'].size(1) for item in batch)

        phoneme_padded = torch.zeros(batch_size, max_input_len, dtype=torch.long)
        duration_padded = torch.zeros(batch_size, max_input_len, dtype=torch.long)
        mel_padded = torch.zeros(batch_size, num_mels, max_target_len, dtype=torch.float32)
        pitch_padded = torch.zeros(batch_size, max_target_len, dtype=torch.float32)
        energy_padded = torch.zeros(batch_size, max_target_len, dtype=torch.float32)
        output_lengths = torch.zeros(batch_size, dtype=torch.long)
        speakers = torch.zeros(batch_size, dtype=torch.long)
        labels, wavs = [], []

        for row, src in enumerate(order):
            item = batch[src]
            phoneme, duration = item['phoneme'], item['duration']
            mel, pitch, energy = item['mel'], item['pitch'], item['energy']

            phoneme_padded[row, :phoneme.size(0)] = phoneme
            duration_padded[row, :duration.size(0)] = duration
            mel_padded[row, :, :mel.size(1)] = mel
            pitch_padded[row, :pitch.size(0)] = pitch
            energy_padded[row, :energy.size(0)] = energy
            output_lengths[row] = mel.size(1)
            speakers[row] = item['speaker']
            labels.append(item['text'])
            wavs.append(item['audio_path'])

        # (B, n_mels, frames) -> (B, frames, n_mels)
        return (
            phoneme_padded,
            speakers,
            input_lengths,
            mel_padded.permute(0, 2, 1),
            pitch_padded,
            energy_padded,
            duration_padded,
            output_lengths,
            labels,
            wavs,
        )

# Sanity-check the collate function on one shuffled training batch.
dataset = FastSpeech2Dataset(mode='train')
collate_fn = TextMelCollateWithAlignment()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)

for batch in dataloader:
    (phoneme, speakers, input_lengths, mel, pitch, energy,
     duration, output_lengths, labels, wavs) = batch
    # Tensors are printed in full; the two string lists are truncated to 5 entries.
    summary = [
        ('Phoneme shape:', phoneme.shape),
        ('Input lengths:', input_lengths),
        ('Mel shape:', mel.shape),
        ('Pitch shape:', pitch.shape),
        ('Energy shape:', energy.shape),
        ('Duration shape:', duration.shape),
        ('Output lengths:', output_lengths),
        ('Speakers:', speakers),
        ('Labels:', labels[:5]),
        ('Wavs:', wavs[:5]),
    ]
    for caption, value in summary:
        print(caption, value)
    break

Phoneme shape: torch.Size([16, 58])
Input lengths: tensor([3, 3, 0, 2, 3, 3, 3, 2, 3, 0, 3, 0, 0, 2, 3, 3])
Mel shape: torch.Size([16])
Pitch shape: torch.Size([16, 506, 80])
Energy shape: torch.Size([16, 506])
Duration shape: torch.Size([16, 506])
Output lengths: tensor([[ 5,  4, 11, 13, 15,  5, 10,  5,  8, 26,  3,  7,  3,  7, 12, 25,  6, 12,
          5,  3,  5,  7,  2,  7,  3,  3,  3,  4, 13,  2,  3,  3,  5,  2, 10, 14,
         39,  4,  5,  4, 11, 10,  2,  9,  4,  8, 68,  7,  4,  2,  2,  6, 14, 13,
         10,  6,  4,  3],
        [ 2,  6,  3,  8,  5,  8,  5,  3,  3,  9,  7,  8, 19,  5,  5, 12, 18,  9,
         13, 10,  2, 13, 23,  2,  3,  5,  3, 28,  5,  2, 34,  7, 14,  4, 12, 36,
          7,  9,  6,  8,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0],
        [ 3,  2,  8,  3,  7,  6,  5,  2,  4,  5,  5, 11,  3, 13,  5,  9, 12,  5,
          5,  2,  3,  2,  9,  2,  4,  5,  4,  4,  2,  3,  5,  4,  2,  2,  3,  4,
          3,  3, 14,  0,  0,  0,  0

In [7]:
# move to designated device
def batch_to_device(batch, device):
    """Return a collated batch with its tensor fields moved to ``device``.

    Fields 0-7 (phoneme, speaker, input_lengths, mel, pitch, energy, duration,
    output_lengths) are moved with ``.to(device)``. Fields 8 and 9 (label
    strings and wav file paths) are passed through unchanged. The result is
    always a 10-tuple.
    """
    moved = tuple(field.to(device) for field in batch[:8])
    passthrough = (batch[8], batch[9])
    return moved + passthrough

### FastSpeech2 model (SpeechBrain implementation with an added speaker embedding)

In [8]:
"""
Neural network modules for the FastSpeech 2: Fast and High-Quality End-to-End Text to Speech
synthesis model
Authors
* Sathvik Udupa 2022
* Pradnya Kandarkar 2023
* Yingzhi Wang 2023
"""

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.loss import _Loss

from speechbrain.lobes.models.transformer.Transformer import (
    PositionalEncoding,
    TransformerEncoder,
    get_key_padding_mask,
    get_mask_from_lengths,
)
from speechbrain.nnet import CNN, linear
from speechbrain.nnet.embedding import Embedding
from speechbrain.nnet.losses import bce_loss
from speechbrain.nnet.normalization import LayerNorm

from speechbrain.lobes.models.FastSpeech2 import (
    EncoderPreNet,
    DurationPredictor,
    PostNet,
    upsample,
    average_over_durations,
    SSIMLoss,
)




class FastSpeech2(nn.Module):
    """The FastSpeech2 text-to-speech model, with an added speaker embedding.
    This class is the main entry point for the model, which is responsible
    for instantiating all submodules, which, in turn, manage the individual
    neural network layers
    Simplified STRUCTURE: input->token embedding ->encoder ->(+ speaker embedding)
    ->duration/pitch/energy predictor ->duration upsampler -> decoder -> output
    During training, teacher forcing is used (ground truth durations are used for upsampling)

    Compared with SpeechBrain's FastSpeech2 (whose building blocks are imported
    above), this version adds a learned speaker embedding, with one vector per
    entry of the module-level SPEAKERS list, to the encoder output. ``forward``
    therefore takes a ``speakers`` argument.

    Arguments
    ---------
    enc_num_layers: int
        number of transformer layers (TransformerEncoderLayer) in encoder
    enc_num_head: int
        number of multi-head-attention (MHA) heads in encoder transformer layers
    enc_d_model: int
        the number of expected features in the encoder
    enc_ffn_dim: int
        the dimension of the feedforward network model
    enc_k_dim: int
        the dimension of the key
    enc_v_dim: int
        the dimension of the value
    enc_dropout: float
        Dropout for the encoder
    dec_num_layers: int
        number of transformer layers (TransformerEncoderLayer) in decoder
    dec_num_head: int
        number of multi-head-attention (MHA) heads in decoder transformer layers
    dec_d_model: int
        the number of expected features in the decoder
    dec_ffn_dim: int
        the dimension of the feedforward network model
    dec_k_dim: int
        the dimension of the key
    dec_v_dim: int
        the dimension of the value
    dec_dropout: float
        dropout for the decoder
    normalize_before: bool
        whether normalization should be applied before or after MHA or FFN in Transformer layers.
    ffn_type: str
        whether to use convolutional layers instead of feed forward network inside transformer layer.
    ffn_cnn_kernel_size_list: list of int
        conv kernel size of 2 1d-convs if ffn_type is 1dcnn
    n_char: int
        the number of symbols for the token embedding
    n_mels: int
        number of bins in mel spectrogram
    postnet_embedding_dim: int
       output feature dimension for convolution layers
    postnet_kernel_size: int
       postnet convolution kernel size
    postnet_n_convolutions: int
       number of convolution layers
    postnet_dropout: float
        dropout probability for postnet
    padding_idx: int
        the index for padding
    dur_pred_kernel_size: int
        the convolution kernel size in duration predictor. NOTE(review): the
        pitch and energy predictors are also built with this kernel size.
    pitch_pred_kernel_size: int
        kernel size of the convolution that embeds pitch values (pitchEmbed);
        it is not used by the pitch predictor itself.
    energy_pred_kernel_size: int
        kernel size of the convolution that embeds energy values (energyEmbed);
        it is not used by the energy predictor itself.
    variance_predictor_dropout: float
        dropout probability for variance predictor (duration/pitch/energy)

    Example
    -------
    >>> import torch
    >>> model = FastSpeech2(  # the class defined in this cell, not SpeechBrain's
    ...    enc_num_layers=6,
    ...    enc_num_head=2,
    ...    enc_d_model=384,
    ...    enc_ffn_dim=1536,
    ...    enc_k_dim=384,
    ...    enc_v_dim=384,
    ...    enc_dropout=0.1,
    ...    dec_num_layers=6,
    ...    dec_num_head=2,
    ...    dec_d_model=384,
    ...    dec_ffn_dim=1536,
    ...    dec_k_dim=384,
    ...    dec_v_dim=384,
    ...    dec_dropout=0.1,
    ...    normalize_before=False,
    ...    ffn_type='1dcnn',
    ...    ffn_cnn_kernel_size_list=[9, 1],
    ...    n_char=40,
    ...    n_mels=80,
    ...    postnet_embedding_dim=512,
    ...    postnet_kernel_size=5,
    ...    postnet_n_convolutions=5,
    ...    postnet_dropout=0.5,
    ...    padding_idx=0,
    ...    dur_pred_kernel_size=3,
    ...    pitch_pred_kernel_size=3,
    ...    energy_pred_kernel_size=3,
    ...    variance_predictor_dropout=0.5)
    >>> inputs = torch.tensor([
    ...     [13, 12, 31, 14, 19],
    ...     [31, 16, 30, 31, 0],
    ... ])
    >>> speakers = torch.tensor([0, 1])  # indices into SPEAKERS
    >>> input_lengths = torch.tensor([5, 4])
    >>> durations = torch.tensor([
    ...     [2, 4, 1, 5, 3],
    ...     [1, 2, 4, 3, 0],
    ... ])
    >>> mel_post, postnet_output, predict_durations, predict_pitch, avg_pitch, predict_energy, avg_energy, mel_lens = model(inputs, speakers, durations=durations)
    >>> mel_post.shape, predict_durations.shape
    (torch.Size([2, 15, 80]), torch.Size([2, 5]))
    >>> predict_pitch.shape, predict_energy.shape
    (torch.Size([2, 5, 1]), torch.Size([2, 5, 1]))
    """

    def __init__(
        self,
        # encoder parameters
        enc_num_layers,
        enc_num_head,
        enc_d_model,
        enc_ffn_dim,
        enc_k_dim,
        enc_v_dim,
        enc_dropout,
        # decoder parameters
        dec_num_layers,
        dec_num_head,
        dec_d_model,
        dec_ffn_dim,
        dec_k_dim,
        dec_v_dim,
        dec_dropout,
        normalize_before,
        ffn_type,
        ffn_cnn_kernel_size_list,
        n_char,
        n_mels,
        postnet_embedding_dim,
        postnet_kernel_size,
        postnet_n_convolutions,
        postnet_dropout,
        padding_idx,
        dur_pred_kernel_size,
        pitch_pred_kernel_size,
        energy_pred_kernel_size,
        variance_predictor_dropout,
    ):
        super().__init__()
        self.enc_num_head = enc_num_head
        self.dec_num_head = dec_num_head
        self.padding_idx = padding_idx
        self.sinusoidal_positional_embed_encoder = PositionalEncoding(
            enc_d_model
        )
        self.sinusoidal_positional_embed_decoder = PositionalEncoding(
            dec_d_model
        )

        # Speaker lookup table: one enc_d_model vector per entry of the global
        # SPEAKERS list. No padding row is reserved (padding_idx is commented out).
        self.speaker_emb = Embedding(
            num_embeddings=len(SPEAKERS),
            embedding_dim=enc_d_model,
            # padding_idx=padding_idx,
        )
        self.encPreNet = EncoderPreNet(
            n_char, padding_idx, out_channels=enc_d_model
        )
        self.durPred = DurationPredictor(
            in_channels=enc_d_model,
            out_channels=enc_d_model,
            kernel_size=dur_pred_kernel_size,
            dropout=variance_predictor_dropout,
        )
        # NOTE(review): pitchPred and energyPred are built with
        # dur_pred_kernel_size; pitch_pred_kernel_size / energy_pred_kernel_size
        # only size the pitchEmbed / energyEmbed convolutions below.
        self.pitchPred = DurationPredictor(
            in_channels=enc_d_model,
            out_channels=enc_d_model,
            kernel_size=dur_pred_kernel_size,
            dropout=variance_predictor_dropout,
        )
        self.energyPred = DurationPredictor(
            in_channels=enc_d_model,
            out_channels=enc_d_model,
            kernel_size=dur_pred_kernel_size,
            dropout=variance_predictor_dropout,
        )
        # Project a single pitch / energy value per token up to enc_d_model so
        # it can be added to the token features.
        self.pitchEmbed = CNN.Conv1d(
            in_channels=1,
            out_channels=enc_d_model,
            kernel_size=pitch_pred_kernel_size,
            padding="same",
            skip_transpose=True,
        )

        self.energyEmbed = CNN.Conv1d(
            in_channels=1,
            out_channels=enc_d_model,
            kernel_size=energy_pred_kernel_size,
            padding="same",
            skip_transpose=True,
        )
        self.encoder = TransformerEncoder(
            num_layers=enc_num_layers,
            nhead=enc_num_head,
            d_ffn=enc_ffn_dim,
            d_model=enc_d_model,
            kdim=enc_k_dim,
            vdim=enc_v_dim,
            dropout=enc_dropout,
            activation=nn.ReLU,
            normalize_before=normalize_before,
            ffn_type=ffn_type,
            ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
        )

        self.decoder = TransformerEncoder(
            num_layers=dec_num_layers,
            nhead=dec_num_head,
            d_ffn=dec_ffn_dim,
            d_model=dec_d_model,
            kdim=dec_k_dim,
            vdim=dec_v_dim,
            dropout=dec_dropout,
            activation=nn.ReLU,
            normalize_before=normalize_before,
            ffn_type=ffn_type,
            ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
        )

        self.linear = linear.Linear(n_neurons=n_mels, input_size=dec_d_model)
        self.postnet = PostNet(
            n_mel_channels=n_mels,
            postnet_embedding_dim=postnet_embedding_dim,
            postnet_kernel_size=postnet_kernel_size,
            postnet_n_convolutions=postnet_n_convolutions,
            postnet_dropout=postnet_dropout,
        )


    def forward(
        self,
        tokens,
        speakers,
        durations=None,
        pitch=None,
        energy=None,
        pace=1.0,
        pitch_rate=1.0,
        energy_rate=1.0,
    ):
        """forward pass for training and inference

        Arguments
        ---------
        tokens: torch.Tensor
            batch of input tokens (padded with padding_idx)
        speakers: torch.Tensor
            speaker index (into SPEAKERS) for each utterance; its embedding is
            broadcast over all token positions. Presumably shape (batch,) —
            the embedding output is unsqueezed on dim 1 before broadcasting.
        durations: torch.Tensor
            batch of durations for each token. If it is None, the model will infer on predicted durations.
            Must be given whenever pitch or energy is given (they are averaged over it).
        pitch: torch.Tensor
            batch of frame-level pitch. If given, it is averaged per token using
            durations and that average is embedded; if None, the model will infer on predicted pitches
        energy: torch.Tensor
            batch of frame-level energy. If given, it is averaged per token using
            durations and that average is embedded; if None, the model will infer on predicted energies
        pace: float
            scaling factor for durations
        pitch_rate: float
            scaling factor applied to predicted pitches (not to a given pitch)
        energy_rate: float
            scaling factor applied to predicted energies (not to a given energy)

        Returns
        -------
        mel_post: torch.Tensor
            mel outputs from the decoder's linear projection (masked by length)
        postnet_output: torch.Tensor
            mel_post plus the postnet's residual refinement
        predict_durations: torch.Tensor
            predicted durations of each token, in the log domain (expm1 inverts them)
        predict_pitch: torch.Tensor
            predicted pitches of each token
        avg_pitch: torch.Tensor
            target pitches for each token if input pitch is not None
            None if input pitch is None
        predict_energy: torch.Tensor
            predicted energies of each token
        avg_energy: torch.Tensor
            target energies for each token if input energy is not None
            None if input energy is None
        mel_length:
            per-utterance frame counts after upsampling — computed from
            `durations` when given, otherwise from the predicted durations
        """
        # Padding mask over tokens; its negation is used as a keep-mask that
        # zeroes features at padded positions.
        srcmask = get_key_padding_mask(tokens, pad_idx=self.padding_idx)
        srcmask_inverted = (~srcmask).unsqueeze(-1)

        # prenet & encoder
        token_feats = self.encPreNet(tokens)
        pos = self.sinusoidal_positional_embed_encoder(token_feats)
        token_feats = torch.add(token_feats, pos) * srcmask_inverted
        # Per-head attention mask built from the padding mask.
        attn_mask = (
            srcmask.unsqueeze(-1)
            .repeat(self.enc_num_head, 1, token_feats.shape[1])
            .permute(0, 2, 1)
            .bool()
        )
        token_feats, _ = self.encoder(
            token_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
        )
        token_feats = token_feats * srcmask_inverted

        # ADD SPEAKER EMBEDDING -- modification.
        # One speaker vector per utterance, repeated across every token position
        # (padded positions included; the predictors below receive the mask).
        token_feats = token_feats + self.speaker_emb(speakers).unsqueeze(1).expand(
            -1, token_feats.shape[1], -1
        )

        # duration predictor (log-domain output)
        predict_durations = self.durPred(token_feats, srcmask_inverted).squeeze(
            -1
        )

        # Restore a batch dimension if the squeeze left a 1-D tensor.
        if predict_durations.dim() == 1:
            predict_durations = predict_durations.unsqueeze(0)
        if durations is None:
            # Inference: convert log durations back to linear, clamped at 0.
            dur_pred_reverse_log = torch.clamp(
                torch.special.expm1(predict_durations), 0
            )

        # pitch predictor
        avg_pitch = None
        predict_pitch = self.pitchPred(token_feats, srcmask_inverted)
        # use a pitch rate to adjust the pitch
        predict_pitch = predict_pitch * pitch_rate
        if pitch is not None:
            # Teacher forcing: embed the ground-truth pitch averaged per token.
            avg_pitch = average_over_durations(pitch.unsqueeze(1), durations)
            pitch = self.pitchEmbed(avg_pitch)
            avg_pitch = avg_pitch.permute(0, 2, 1)
        else:
            pitch = self.pitchEmbed(predict_pitch.permute(0, 2, 1))
        pitch = pitch.permute(0, 2, 1)
        token_feats = token_feats.add(pitch)

        # energy predictor (sees token features that already include the pitch embedding)
        avg_energy = None
        predict_energy = self.energyPred(token_feats, srcmask_inverted)
        # use an energy rate to adjust the energy
        predict_energy = predict_energy * energy_rate
        if energy is not None:
            # Teacher forcing: embed the ground-truth energy averaged per token.
            avg_energy = average_over_durations(energy.unsqueeze(1), durations)
            energy = self.energyEmbed(avg_energy)
            avg_energy = avg_energy.permute(0, 2, 1)
        else:
            energy = self.energyEmbed(predict_energy.permute(0, 2, 1))
        energy = energy.permute(0, 2, 1)
        token_feats = token_feats.add(energy)

        # upsamples the durations: token features -> frame features, using the
        # ground-truth durations when given, else the predicted ones.
        spec_feats, mel_lens = upsample(
            token_feats,
            durations if durations is not None else dur_pred_reverse_log,
            pace=pace,
        )
        # Rebuild the padding/attention masks at frame level from mel_lens.
        srcmask = get_mask_from_lengths(torch.tensor(mel_lens))
        srcmask = srcmask.to(spec_feats.device)
        srcmask_inverted = (~srcmask).unsqueeze(-1)
        attn_mask = (
            srcmask.unsqueeze(-1)
            .repeat(self.dec_num_head, 1, spec_feats.shape[1])
            .permute(0, 2, 1)
            .bool()
        )

        # decoder
        pos = self.sinusoidal_positional_embed_decoder(spec_feats)
        spec_feats = torch.add(spec_feats, pos) * srcmask_inverted

        output_mel_feats, memory, *_ = self.decoder(
            spec_feats, src_mask=attn_mask, src_key_padding_mask=srcmask
        )

        # postnet
        mel_post = self.linear(output_mel_feats) * srcmask_inverted
        postnet_output = self.postnet(mel_post) + mel_post
        return (
            mel_post,
            postnet_output,
            predict_durations,
            predict_pitch,
            avg_pitch,
            predict_energy,
            avg_energy,
            torch.tensor(mel_lens),
        )








class Loss(nn.Module):
    """Loss Computation for the FastSpeech2 model above.

    Combines SSIM, mel, postnet-mel, duration, pitch and energy terms, each
    scaled by its weight. Mel and postnet-mel losses are per-frame (masked by
    the mel length). Duration, pitch and energy losses are per-token (masked by
    the phoneme count).

    Arguments
    ---------
    log_scale_durations: bool
        compare predicted durations against log1p(target durations); when
        False, targets are used as-is (linear durations)
    ssim_loss_weight: float
        weight for ssim loss
    duration_loss_weight: float
        weight for the duration loss
    pitch_loss_weight: float
        weight for the pitch loss
    energy_loss_weight: float
        weight for the energy loss
    mel_loss_weight: float
        weight for the mel loss
    postnet_mel_loss_weight: float
        weight for the postnet mel loss
    spn_loss_weight: float
        weight for spn loss (stored; the spn term is currently disabled)
    spn_loss_max_epochs: int
        Max number of epochs (stored; the spn term is currently disabled)
    """

    def __init__(
        self,
        log_scale_durations,
        ssim_loss_weight,
        duration_loss_weight,
        pitch_loss_weight,
        energy_loss_weight,
        mel_loss_weight,
        postnet_mel_loss_weight,
        spn_loss_weight=1.0,
        spn_loss_max_epochs=8,
    ):
        super().__init__()

        self.ssim_loss = SSIMLoss()
        self.mel_loss = nn.MSELoss()
        self.postnet_mel_loss = nn.MSELoss()
        self.dur_loss = nn.MSELoss()
        self.pitch_loss = nn.MSELoss()
        self.energy_loss = nn.MSELoss()
        self.log_scale_durations = log_scale_durations
        self.ssim_loss_weight = ssim_loss_weight
        self.mel_loss_weight = mel_loss_weight
        self.postnet_mel_loss_weight = postnet_mel_loss_weight
        self.duration_loss_weight = duration_loss_weight
        self.pitch_loss_weight = pitch_loss_weight
        self.energy_loss_weight = energy_loss_weight
        self.spn_loss_weight = spn_loss_weight
        self.spn_loss_max_epochs = spn_loss_max_epochs


    def forward(self, predictions, targets, current_epoch):
        """Computes the value of the loss function and updates stats

        Arguments
        ---------
        predictions: tuple
            the 8-tuple returned by FastSpeech2.forward (called with ground-truth
            durations, pitch and energy, so avg_pitch / avg_energy are not None)
        targets: tuple
            (mel_target, target_durations, target_pitch, target_energy,
            mel_length, phon_len). The frame-level target_pitch / target_energy
            are not used; the token-level averages from `predictions` are.
        current_epoch: int
            The count of the current epoch (unused while the spn term is disabled).

        Returns
        -------
        loss: dict
            "total_loss" plus each weighted component
        """
        (
            mel_target,
            target_durations,
            target_pitch,
            target_energy,
            mel_length,
            phon_len,
            # spn_labels,
        ) = targets
        assert len(mel_target.shape) == 3
        (
            mel_out,
            postnet_mel_out,
            log_durations,
            predicted_pitch,
            average_pitch,
            predicted_energy,
            average_energy,
            mel_lens,
            # spn_preds,
        ) = predictions

        # Pitch / energy are predicted per token, so compare them with the
        # per-token averages the model computed from the frame-level targets.
        predicted_pitch = predicted_pitch.squeeze(-1)
        predicted_energy = predicted_energy.squeeze(-1)
        target_pitch = average_pitch.squeeze(-1)
        target_energy = average_energy.squeeze(-1)

        # Only drop a trailing channel dim; an unconditional squeeze(-1) would
        # collapse a (batch, 1) tensor when every utterance has a single token.
        if log_durations.dim() == 3:
            log_durations = log_durations.squeeze(-1)
        if self.log_scale_durations:
            duration_targets = torch.log1p(target_durations.float())
        else:
            duration_targets = target_durations.float()

        batch_size = mel_target.shape[0]
        mel_loss = postnet_mel_loss = dur_loss = pitch_loss = energy_loss = 0.0
        for i in range(batch_size):
            n_frames = mel_length[i]
            n_tokens = phon_len[i]
            mel_loss = mel_loss + self.mel_loss(
                mel_out[i, :n_frames, :],
                mel_target[i, :n_frames, :],
            )
            postnet_mel_loss = postnet_mel_loss + self.postnet_mel_loss(
                postnet_mel_out[i, :n_frames, :],
                mel_target[i, :n_frames, :],
            )
            dur_loss = dur_loss + self.dur_loss(
                log_durations[i, :n_tokens],
                duration_targets[i, :n_tokens].to(torch.float32),
            )
            # Token-level terms are masked by the phoneme count: slicing the
            # token axis by the (larger) frame count would keep padded tokens.
            pitch_loss = pitch_loss + self.pitch_loss(
                predicted_pitch[i, :n_tokens],
                target_pitch[i, :n_tokens].to(torch.float32),
            )
            energy_loss = energy_loss + self.energy_loss(
                predicted_energy[i, :n_tokens],
                target_energy[i, :n_tokens].to(torch.float32),
            )
        ssim_loss = self.ssim_loss(mel_out, mel_target, mel_length)
        mel_loss = torch.div(mel_loss, len(mel_target))
        postnet_mel_loss = torch.div(postnet_mel_loss, len(mel_target))
        dur_loss = torch.div(dur_loss, len(mel_target))
        pitch_loss = torch.div(pitch_loss, len(mel_target))
        energy_loss = torch.div(energy_loss, len(mel_target))

        # spn_loss = bce_loss(spn_preds, spn_labels)
        # if current_epoch > self.spn_loss_max_epochs:
        #     self.spn_loss_weight = 0

        total_loss = (
            ssim_loss * self.ssim_loss_weight
            + mel_loss * self.mel_loss_weight
            + postnet_mel_loss * self.postnet_mel_loss_weight
            + dur_loss * self.duration_loss_weight
            + pitch_loss * self.pitch_loss_weight
            + energy_loss * self.energy_loss_weight
            # + spn_loss * self.spn_loss_weight
        )

        loss = {
            "total_loss": total_loss,
            "ssim_loss": ssim_loss * self.ssim_loss_weight,
            "mel_loss": mel_loss * self.mel_loss_weight,
            "postnet_mel_loss": postnet_mel_loss * self.postnet_mel_loss_weight,
            "dur_loss": dur_loss * self.duration_loss_weight,
            "pitch_loss": pitch_loss * self.pitch_loss_weight,
            "energy_loss": energy_loss * self.energy_loss_weight,
            # "spn_loss": spn_loss * self.spn_loss_weight,
        }
        return loss

## Trainer

In [9]:
import yaml

# Load the FastSpeech2 model / loss hyper-parameters from the project YAML file.
# The trainer reads config['fastspeech2']['model'] and config['fastspeech2']['loss'].
config_path = '/workspace/emo_rank_tts/params.yaml'
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

In [None]:
from IPython.display import clear_output
from collections import defaultdict

# `import tqdm` alone does not reliably load the `tqdm.notebook` submodule;
# import it explicitly so `tqdm.notebook.tqdm` below always resolves.
import tqdm.notebook


def _save_mel_grid(pred_mels, target_mels, path):
    """Save predicted and ground-truth mel-spectrograms into one image file.

    Args:
        pred_mels: NumPy array of predicted mels, one item per entry on axis 0.
        target_mels: NumPy array of ground-truth mels. It is concatenated with
            ``pred_mels`` along axis 0, so all other dimensions must match.
        path: destination file passed to ``Figure.savefig``.

    The figure is a fixed 4x4 grid (16 panels). Predictions fill the first
    panels (blue "Pred" labels) and ground truth the rest (red "GT" labels).
    Items beyond the 16th panel are not drawn.
    """
    all_mels = np.concatenate((pred_mels, target_mels), axis=0)
    fig, axes = plt.subplots(4, 4, figsize=(16, 10))
    for ax_idx, (ax, mel) in enumerate(zip(axes.flatten(), all_mels)):
        # Presumably each mel is (frames, n_mels), so the transpose puts mel
        # bins on the y-axis -- TODO confirm against the model output layout.
        ax.imshow(mel.T, aspect='auto', origin='lower', interpolation='none')
        if ax_idx < len(pred_mels):
            label = f"Pred {ax_idx + 1}"
            color = 'blue'
        else:
            label = f"GT {ax_idx - len(pred_mels) + 1}"
            color = 'red'

        ax.text(
            0.95, 0.95, label,
            horizontalalignment='right',
            verticalalignment='top',
            transform=ax.transAxes,
            fontsize=12,
            fontweight='bold',
            color=color,
        )
    fig.tight_layout()
    fig.savefig(path)
    # One figure is created per epoch; close it so figures do not pile up.
    plt.close(fig)


# misc
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# dataset
dataset = FastSpeech2Dataset(mode='train')
collate_fn = TextMelCollateWithAlignment()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,  # 8 predictions + 8 targets exactly fill the 4x4 preview grid
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)

# model
model = FastSpeech2(**config['fastspeech2']['model']).to(device)

# optimizer
optim = torch.optim.Adam(
    model.parameters(),
    lr=1e-4
)

# loss
criterion = Loss(**config['fastspeech2']['loss']).to(device)


global_step = 0
# NOTE(review): runs epochs 1..399; N_EPOCHS / MAX_ITERATIONS from the config
# cell are not used here.
for epoch in range(1, 400):

    epoch_avg_loss = defaultdict(float)

    for idx, batch in enumerate(tqdm.notebook.tqdm(dataloader)):
        batch = batch_to_device(batch, device)
        phoneme, speakers, phon_len, mel_target, target_pitch, target_energy, target_duration, mel_length, labels, wavs = batch
        global_step += 1

        # Forward pass: ground-truth duration, pitch and energy are fed to the model.
        predictions = model(phoneme, speakers, target_duration, target_pitch, target_energy)

        # Compute loss
        targets = (mel_target, target_duration, target_pitch, target_energy, mel_length, phon_len)

        optim.zero_grad()
        loss = criterion(predictions, targets, epoch)
        loss['total_loss'].backward()
        optim.step()

        # Accumulate plain Python floats. Summing the raw loss tensors would keep
        # extending one autograd graph across the whole epoch (referencing every
        # batch's graph) and leave the averages as device tensors.
        for loss_name, loss_value in loss.items():
            epoch_avg_loss[loss_name] += loss_value.item()

        # Save a preview of predicted vs. ground-truth mels for the first batch.
        if idx == 0:
            _save_mel_grid(
                predictions[0].cpu().detach().numpy(),
                mel_target.cpu().detach().numpy(),
                'epoch_{}.png'.format(epoch),
            )

    # End of epoch: with drop_last=True, len(dataloader) equals the number of
    # iterations above, so this is the mean over processed batches.
    epoch_avg_loss = {k: v / len(dataloader) for k, v in epoch_avg_loss.items()}
    print("=" * 50)
    print("Epoch: {}".format(epoch))
    print("=" * 50)
    for loss_name, loss_value in epoch_avg_loss.items():
        print("{:<30s}{:>20.4f}".format(loss_name, loss_value))
    print("=" * 50, '\n\n')

## Experiment: Intensity extractor test - `2025-06-05`

- <span style="color:red">fix #1</span>: Rank model의 intensity extractor 추정 **Done**
- <span style="color:red">fix #2</span>: Rank loss를 Rank model의 외부의 별도 class로 지정 **Done**
- <span style="color:red">fix #3</span>: Rank model의 output 수정: `H_i, H_j, h_i, h_j, r_i, r_j` 의 값을 반환하도록 설정 **Done**
- <span style="color:red">fix #4</span>: Rank model 별도의 intensity extractor class 생성 **Done**

---

- <span style="color:blue">imp #1</span>: intensity extractor의 output $\mathbf{I}$ 에 대해 phoneme-wise하게 평균을 취하여 크기 변경: [$B$, $T_{mel}$, $H$] $\rightarrow$ [$B$, $T_{phone}$, $H$]
- <span style="color:blue">imp #2</span>: Speaker ID를 사용할 것인지 speaker embedding을 사용할 것인지 실험을 통해 도출 (논문에서는 speaker id)
- <span style="color:blue">imp #3</span>: `phoneme_encoder_output`과 `intensity_representation`, `speaker_id` 를 concat하여 variance adaptor의 입력으로 feed.
- <span style="color:blue">imp #4</span>: 추론 시에는 intensity_representation을 명시적으로 구할 수 없으므로 manual label을 사용 -> manual label을 구하기 위한 clustering 필요 ($N$-level averaging)

---

- Rank model 재학습 후, train dataset에 대한 intensity score 추출
- Intensity score를 $N$ 개로 bucketize (min - median - max).
- Speaker 별, emotion 별

In [None]:
# The checkpoint holds a whole pickled rank model (its `.intensity_extractor`
# attribute is accessed below), so torch.load unpickles arbitrary objects --
# only load trusted files. map_location lets a checkpoint saved on GPU load on
# a CPU-only machine too.
# NOTE(review): on torch>=2.6, torch.load defaults to weights_only=True, which
# rejects pickled modules; pass weights_only=False there if loading fails.
rank_model = torch.load('/workspace/experiments/exp_3/best_model.pth', map_location=device)
intensity_extractor = rank_model.intensity_extractor.to(device)

# imp #1: average the intensity extractor output I over each phoneme's frames.
# Open question: how should TextGrid durations be handled for utterances whose
# audio was trimmed?
# The time dimension of I is expected to match the phoneme sequence duration
# (the sum of per-phoneme frame counts), so I can be averaged per phoneme.


def average_intensity_by_phoneme(intensity, duration):
    """Average frame-level intensity features over each phoneme's frame span.

    Args:
        intensity: tensor whose first dimension is time (frames). Presumably
            a single utterance of shape [T_mel, H] -- TODO confirm against
            the intensity extractor's output.
        duration: iterable of per-phoneme frame counts (ints or 0-dim
            tensors), in phoneme order. Frames are consumed left to right.

    Returns:
        Tensor of shape [len(duration), *intensity.shape[1:]] holding the mean
        over each phoneme's frames. A phoneme that covers no frames (zero
        duration, or durations running past the end of ``intensity``) gets
        zeros instead of the NaN an empty mean would produce.
    """
    feature_shape = intensity.shape[1:]
    phoneme_means = []
    start_idx = 0
    for d in duration:
        d = int(d)
        segment = intensity[start_idx:start_idx + d]
        if segment.shape[0] > 0:
            phoneme_means.append(segment.mean(dim=0))
        else:
            phoneme_means.append(intensity.new_zeros(feature_shape))
        start_idx += d
    if not phoneme_means:
        return intensity.new_zeros((0, *feature_shape))
    return torch.stack(phoneme_means)


# TODO: apply to the extractor output once the per-utterance `intensity`
# and its matching phoneme `duration` are available, e.g.
#     averaged_intensity = average_intensity_by_phoneme(intensity, duration)

# TODO (fine_grained_emo_tts): extract manual intensity labels from the train dataset.



In [32]:
# The .npz stores an object array wrapping a dict under 'intensity' (hence
# allow_pickle=True and .item()); ['bea'] selects that speaker's entry.
# allow_pickle unpickles arbitrary objects -- only load trusted files.
# The context manager closes the underlying NpzFile once the value is read.
with np.load('/workspace/emo_rank_tts/rank_model/result.npz', allow_pickle=True) as npz:
    result = npz['intensity'].item()['bea']