In [None]:
import os, sys
sys.path.append("../")
from abc import ABC

import torch

from src.tts.layers.VariancePredictor import VariancePredictor
from src.utility.utils import make_pad_mask

In [None]:
class StyleAdaptor(torch.nn.Module):
    """
    Style Adaptor module.

    This module is the same as the variance adaptor in FastSpeech2, but
    without the duration predictor and the length regulator. It predicts a
    pitch curve and an energy curve from the input features, embeds one
    pitch curve and one energy curve (gold or predicted, see ``forward``)
    into ``adim`` channels with a Conv1d each, and sums the two embeddings.

    Args:
        adim (int): Input feature size given to both VariancePredictors
            (``idim``) and output channel count of both curve embeddings.
        pitch_predictor_layers (int): ``n_layers`` of the pitch predictor.
        pitch_predictor_chans (int): ``n_chans`` of the pitch predictor.
        pitch_predictor_kernel_size (int): ``kernel_size`` of the pitch
            predictor.
        pitch_predictor_dropout (float): ``dropout_rate`` of the pitch
            predictor.
        pitch_embed_kernel_size (int): Kernel size of the pitch embedding
            Conv1d. Padding is ``(k - 1) // 2``, which keeps the time length
            unchanged only for odd ``k``.
        pitch_embed_dropout (float): Dropout applied after the pitch
            embedding Conv1d.
        energy_predictor_layers, energy_predictor_chans,
        energy_predictor_kernel_size, energy_predictor_dropout,
        energy_embed_kernel_size, energy_embed_dropout:
            Same as the corresponding ``pitch_*`` arguments, for the energy
            branch.
    """

    def __init__(
        self,
        adim=384,
        # pitch predictor
        pitch_predictor_layers=5,
        pitch_predictor_chans=256,
        pitch_predictor_kernel_size=5,
        pitch_predictor_dropout=0.5,
        pitch_embed_kernel_size=1,
        pitch_embed_dropout=0.0,
        # energy predictor
        energy_predictor_layers=2,
        energy_predictor_chans=256,
        energy_predictor_kernel_size=3,
        energy_predictor_dropout=0.5,
        energy_embed_kernel_size=1,
        energy_embed_dropout=0.0,
    ):
        super().__init__()

        # NOTE: submodules are registered in the same order and under the same
        # names as before so state_dict keys / parameter order are unchanged.

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(
            idim=adim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout,
        )
        # continuous pitch + FastPitch style avg
        self.pitch_embed = self._make_curve_embed(
            adim, pitch_embed_kernel_size, pitch_embed_dropout
        )

        # define energy predictor
        self.energy_predictor = VariancePredictor(
            idim=adim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout,
        )
        # continuous energy + FastPitch style avg
        self.energy_embed = self._make_curve_embed(
            adim, energy_embed_kernel_size, energy_embed_dropout
        )

    def forward(
        self,
        xs,
        padding_mask=None,
        gold_pitch=None,
        gold_energy=None,
        is_inference=False
    ):
        """
        Predict pitch/energy and embed a pitch and an energy curve.

        Args:
            xs (Tensor): Input features fed to both predictors; presumably
                (B, Tmax, adim) — TODO confirm against VariancePredictor.
            padding_mask (Tensor, optional): Passed unchanged to both
                predictors. The demo below passes
                ``make_pad_mask(lengths).unsqueeze(-1)``; presumably True at
                padded frames — confirm.
            gold_pitch (Tensor, optional): Ground-truth pitch curve of shape
                (B, Tmax, 1). Required when ``is_inference`` is False.
            gold_energy (Tensor, optional): Ground-truth energy curve of shape
                (B, Tmax, 1). Required when ``is_inference`` is False.
            is_inference (bool): If True, embed the predicted curves;
                otherwise embed ``gold_pitch`` / ``gold_energy``.

        Returns:
            tuple: ``(embedded_curve, pitch_predictions, energy_predictions)``.
            ``embedded_curve`` (B, Tmax, adim) is the sum of the pitch and
            energy embeddings. The predictions are always computed and
            returned as produced by the predictors, in both modes.

        Raises:
            ValueError: If ``is_inference`` is False and ``gold_pitch`` or
                ``gold_energy`` is None.
        """
        pitch_predictions = self.pitch_predictor(xs, padding_mask)
        energy_predictions = self.energy_predictor(xs, padding_mask)

        if is_inference:
            pitch_curve, energy_curve = pitch_predictions, energy_predictions
        else:
            # Fail loudly instead of crashing on `None.transpose`.
            if gold_pitch is None or gold_energy is None:
                raise ValueError(
                    "gold_pitch and gold_energy are required when "
                    "is_inference=False"
                )
            pitch_curve, energy_curve = gold_pitch, gold_energy

        embedded_curve = (
            self._embed_curve(self.pitch_embed, pitch_curve)
            + self._embed_curve(self.energy_embed, energy_curve)
        )  # (B, Tmax, adim)

        return (
            embedded_curve,
            pitch_predictions,
            energy_predictions
        )

    @staticmethod
    def _make_curve_embed(adim, kernel_size, dropout):
        """Build the Conv1d(1 -> adim) + Dropout embedding for a 1-channel curve."""
        return torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=kernel_size,
                # preserves the time length only for odd kernel sizes
                padding=(kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(dropout),
        )

    @staticmethod
    def _embed_curve(embed, curve):
        """Apply a curve embedding to a (B, Tmax, 1) curve -> (B, Tmax, adim)."""
        # Conv1d expects channels on dim 1, so move the size-1 feature axis
        # there and back.
        return embed(curve.transpose(1, 2)).transpose(1, 2)

In [None]:
# Seed the RNG so weight initialisation (and the random demo inputs drawn
# afterwards) are reproducible under Restart & Run All.
torch.manual_seed(0)
style_adaptor = StyleAdaptor()

In [None]:
# Demo dimensions: batch size, longest sequence length and feature size
# (384 matches StyleAdaptor's default ``adim``).
B = 2
Tmax = 5
adim = 384

In [None]:
# Random stand-in for encoder outputs plus a padding mask for sequences of
# the given lengths (presumably True at padded frames — confirm against
# make_pad_mask). The lengths must describe the same batch as encoder_texts.
src_lengths = [3, 5]
assert len(src_lengths) == B and max(src_lengths) == Tmax, (
    "src_lengths must have B entries with maximum Tmax"
)
encoder_texts = torch.rand(B, Tmax, adim)
d_masks = make_pad_mask(src_lengths)

In [None]:
# No gold pitch/energy curves exist in this demo, so run in inference mode
# (embed the predicted curves): the training branch of forward() requires
# gold_pitch / gold_energy. no_grad because this is only a shape check.
with torch.no_grad():
    embedded_curve, pitch_predictions, energy_predictions = style_adaptor(
        encoder_texts, padding_mask=d_masks.unsqueeze(-1), is_inference=True
    )

In [None]:
# Shapes of the summed curve embedding and of the two raw predictions.
for output in (embedded_curve, pitch_predictions, energy_predictions):
    print(output.shape)