In [None]:
import logging
import pathlib as pl
import sys
import time
import numpy as np
import torch
import torchaudio
import tqdm
from hyperpyyaml import load_hyperpyyaml
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler, IterableDataset
from speechbrain.dataio.dataloader import LoopedLoader, SaveableDataLoader
import speechbrain as sb
from speechbrain.inference.ASR import EncoderDecoderASR
from speechbrain.inference.vocoders import UnitHIFIGAN,HIFIGAN
from enum import Enum, auto
from IPython.display import Audio, display

  from torchaudio.backend.common import AudioMetaData


In [2]:
class Stage(Enum):
    """Enumerates the phases an experiment can be in.

    The members carry the consecutive integer values 1, 2 and 3, in
    declaration order.
    """

    TRAIN = 1
    VALID = 2
    TEST = 3

In [3]:
def dataio_prepare(hparams):
    """Build one dataset per split, with the audio and unit pipelines attached.

    Three dynamic items are registered on every split: the source audio
    (waveform and mel spectrogram), the target audio (waveform) and the
    target unit codes (a BOS-prefixed and an EOS-suffixed variant).

    Arguments
    ---------
    hparams : dict
        Loaded hyperparameters. Keys read here: ``codes_folder``,
        ``sample_rate``, ``mel_spectogram``, ``bos_index``, ``eos_index``,
        ``splits``, ``<split>_json`` and ``sorting``. The ``shuffle`` entry
        of ``train_dataloader_opts`` and ``valid_dataloader_opts`` is
        overwritten according to ``sorting``.

    Returns
    -------
    dict
        Maps every name in ``hparams["splits"]`` to its dataset.

    Raises
    ------
    NotImplementedError
        If ``hparams["sorting"]`` is not "random", "ascending" or
        "descending".
    """
    codes_folder = pl.Path(hparams["codes_folder"])

    def _read_at_target_rate(path):
        """Read an audio file and resample it to ``hparams["sample_rate"]``."""
        native_rate = torchaudio.info(path).sample_rate
        waveform = sb.dataio.dataio.read_audio(path)
        resampler = torchaudio.transforms.Resample(
            native_rate, hparams["sample_rate"]
        )
        return resampler(waveform)

    @sb.utils.data_pipeline.takes("src_audio")
    @sb.utils.data_pipeline.provides("src_sig", "src_mel")
    def src_audio_pipeline(wav):
        """Load the source language audio and derive its mel spectrogram."""
        sig = _read_at_target_rate(wav)
        mel_spec = hparams["mel_spectogram"](audio=sig).transpose(0, 1)
        return sig, mel_spec

    @sb.utils.data_pipeline.takes("tgt_audio")
    @sb.utils.data_pipeline.provides("tgt_sig")
    def tgt_audio_pipeline(wav):
        """Load the target language audio signal."""
        return _read_at_target_rate(wav)

    @sb.utils.data_pipeline.takes("id")
    @sb.utils.data_pipeline.provides("code_bos", "code_eos")
    def unit_pipeline(utt_id):
        """Load the target codes and collapse consecutive repeats."""
        units = torch.unique_consecutive(
            torch.LongTensor(np.load(codes_folder / f"{utt_id}_tgt.npy"))
        )
        # Decoder input: BOS followed by the units.
        yield torch.cat((torch.LongTensor([hparams["bos_index"]]), units))
        # Decoder target: the units followed by EOS.
        yield torch.cat((units, torch.LongTensor([hparams["eos_index"]])))

    pipelines = [src_audio_pipeline, tgt_audio_pipeline, unit_pipeline]
    requested_keys = [
        "id",
        "src_sig",
        "src_mel",
        "tgt_sig",
        "duration",
        "code_bos",
        "code_eos",
        "tgt_text",
    ]
    datasets = {
        split: sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=hparams[f"{split}_json"],
            dynamic_items=pipelines,
            output_keys=requested_keys,
        )
        for split in hparams["splits"]
    }

    sorting = hparams["sorting"]
    if sorting not in ("random", "ascending", "descending"):
        raise NotImplementedError(
            "sorting must be random, ascending or descending"
        )

    if sorting != "random":
        # Ordering by duration minimises zero-padding inside a batch, which
        # speeds things up and usually does not harm performance.
        longest_first = sorting == "descending"
        for split in ("train", "valid"):
            datasets[split] = datasets[split].filtered_sorted(
                sort_key="duration", reverse=longest_first
            )

    # Only an unsorted training set is shuffled; validation never is.
    hparams["train_dataloader_opts"]["shuffle"] = sorting == "random"
    hparams["valid_dataloader_opts"]["shuffle"] = False

    return datasets

In [None]:
# Module-level logger; S2UT.on_stage_start uses it to report which pretrained
# model is being loaded.
logger = logging.getLogger(__name__)


class S2UT(sb.core.Brain):
    """Brain that maps source speech to a sequence of target units.

    The forward pass chains ``modules.wav2vec2``, ``modules.enc``, the decoder
    of ``modules.transformer`` and ``modules.seq_lin``.
    """

    def compute_forward(self, batch, stage):
        """Computes the teacher-forced forward pass.

        Arguments
        ---------
        batch : PaddedBatch
            An element from the dataloader. ``src_sig`` and ``code_bos`` are
            read from it.
        stage : Stage
            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST.
            Not used by this method.

        Returns
        -------
        torch.Tensor
            The output of ``hparams.log_softmax`` applied to the decoder
            projections.
        """
        batch = batch.to(self.device)
        wavs, wav_lens = batch.src_sig
        tokens_bos, _ = batch.code_bos

        # Use default padding value for wav2vec2
        wavs[wavs == self.hparams.pad_index] = 0.0
        # compute features
        enc_out = self.modules.wav2vec2(wavs, wav_lens)

        # dimensionality reduction
        enc_out = self.modules.enc(enc_out)

        # `forward_mt_decoder_only` is not reachable through the DDP wrapper,
        # so unwrap the underlying module first when running distributed.
        transformer = self.modules.transformer
        if isinstance(transformer, DistributedDataParallel):
            transformer = transformer.module
        dec_out = transformer.forward_mt_decoder_only(
            enc_out, tokens_bos, pad_idx=self.hparams.pad_index
        )

        # logits and softmax
        pred = self.modules.seq_lin(dec_out)
        return self.hparams.log_softmax(pred)

    def compute_objectives(self, p_seq, batch, stage):
        """Computes the loss given the predicted and targeted outputs.

        Arguments
        ---------
        p_seq : torch.Tensor
            The output of `compute_forward`.
        batch : PaddedBatch
            This batch object contains all the relevant tensors for
            computation; ``code_eos`` is read from it.
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
            Not used by this method.

        Returns
        -------
        loss : torch.Tensor
            The value returned by ``hparams.seq_cost``.
        """
        tokens_eos, tokens_eos_lens = batch.code_eos

        # speech translation loss
        loss = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)

        return loss

    def init_optimizers(self):
        """Called during ``on_fit_start()``, initialize optimizers
        after parameters are fully configured (e.g. DDP, jit).

        The wav2vec2 optimizer is only created when
        ``hparams.wav2vec2_frozen`` is false. Every optimizer that was
        created is stored in ``self.optimizers_dict`` and, when a
        checkpointer is present, registered with it under the same key.
        """
        self.optimizers_dict = {}

        # Initializes the wav2vec2 optimizer if the model is not wav2vec2_frozen
        if not self.hparams.wav2vec2_frozen:
            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
                self.modules.wav2vec2.parameters()
            )
            self.optimizers_dict["wav2vec_optimizer"] = self.wav2vec_optimizer

        self.model_optimizer = self.hparams.opt_class(
            self.hparams.model.parameters()
        )
        self.optimizers_dict["model_optimizer"] = self.model_optimizer

        if self.checkpointer is not None:
            # Register only what exists: with a frozen wav2vec2 there is no
            # `self.wav2vec_optimizer`, and referencing it unconditionally
            # raised AttributeError.
            for name, optimizer in self.optimizers_dict.items():
                self.checkpointer.add_recoverable(name, optimizer)

    def on_stage_start(self, stage, epoch):
        """Gets called when a stage starts.

        For every stage other than TRAIN this creates fresh accuracy and BLEU
        metrics and loads the pretrained helper models: unit vocoder, ASR,
        mel vocoder, speaker adapter and variance predictor. All of them are
        placed on the CPU.

        Arguments
        ---------
        stage : Stage
            The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST.
            Either ``sb.Stage`` or the notebook-local ``Stage`` enum.
        epoch : int
            The current epoch count. Not used by this method.

        Returns
        -------
        None
        """
        # Compare by member name so that the notebook-local `Stage` enum and
        # `sb.Stage` are treated alike; members of two different Enum classes
        # never compare equal, so `stage != sb.Stage.TRAIN` was always true
        # for the local enum.
        if getattr(stage, "name", None) == "TRAIN":
            return

        self.acc_metric = self.hparams.acc_computer()
        self.bleu_metric = self.hparams.bleu_computer()
        self.last_batch = None

        cpu_opts = {"device": "cpu"}

        logger.info("Loading pretrained HiFi-GAN ...")
        self.test_vocoder = UnitHIFIGAN.from_hparams(
            source=self.hparams.vocoder_source,
            savedir=self.hparams.vocoder_download_path,
            run_opts=dict(cpu_opts),
        )

        logger.info("Loading pretrained ASR ...")
        self.test_asr = EncoderDecoderASR.from_hparams(
            source=self.hparams.asr_source,
            savedir=self.hparams.asr_download_path,
            run_opts=dict(cpu_opts),
        )

        logger.info("Loading pretrained MEL-HIFI-GAN ...")
        self.mel_vocoder = HIFIGAN.from_hparams(
            source=self.hparams.mel_hifigan_source,
            savedir=self.hparams.mel_hifigan_download_path,
            run_opts=dict(cpu_opts),
        )

        # NOTE(review): torch.load unpickles its input; only point the two
        # `*_source` hyperparameters below at trusted checkpoint files.
        # `map_location="cpu"` lets checkpoints saved on a GPU load on a
        # machine without one.
        logger.info("Loading pretrained speaker_adapter ...")
        self.speaker_adapter = self.hparams.speaker_adapter.eval()
        self.speaker_adapter.load_state_dict(
            torch.load(self.hparams.speaker_adapter_source, map_location="cpu")
        )

        logger.info("Loading pretrained var_predictor ...")
        self.var_predictor = self.hparams.var_predictor.eval()
        self.var_predictor.load_state_dict(
            torch.load(self.hparams.var_predictor_source, map_location="cpu")
        )
            
 

In [None]:
# Parse the inference hyperparameters.
hparams_file = "hparams/inference.yaml"
with open(hparams_file) as hparams_stream:
    hparams = load_hyperpyyaml(hparams_stream)

# Create the experiment directory and keep a copy of the hyperparameters in it.
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
)

# One dataset per split listed in the hyperparameters.
datasets = dataio_prepare(hparams)

s2ut_brain = S2UT(
    modules=hparams["modules"],
    opt_class=hparams["opt_class"],
    hparams=hparams,
    checkpointer=hparams["checkpointer"],
)

# Called by hand because the notebook drives inference itself instead of
# going through `evaluate()`.
# NOTE(review): presumably restores the checkpoint with the best BLEU — confirm.
s2ut_brain.on_evaluate_start(max_key="BLEU")
# Loads the vocoders, the ASR model, the speaker adapter and the variance
# predictor onto the brain.
s2ut_brain.on_stage_start(Stage.TEST, epoch=None)


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-large-es-voxpopuli and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/s2ut/888


  WeightNorm.apply(module, name, dim)


In [None]:
# Wrap the test split in a dataloader that yields one utterance per batch.
test_set = datasets["test"]
test_dataloader_opts = {"batch_size": 1}
# Both names refer to the same dict, so the `ckpt_prefix` entry added below
# is visible through either of them.
test_loader_kwargs = test_dataloader_opts

already_a_loader = isinstance(test_set, (DataLoader, LoopedLoader))
if not already_a_loader:
    test_loader_kwargs["ckpt_prefix"] = None
    test_set = s2ut_brain.make_dataloader(
        test_set, Stage.TEST, **test_loader_kwargs
    )

In [None]:
s2ut_brain.modules.eval()

# Unit ids removed from every hypothesis before vocoding.
# NOTE(review): presumably special (non-speech) tokens — confirm against the
# hyperparameters.
SPECIAL_UNIT_IDS = frozenset({100, 101, 102})
# A hypothesis needs more than this many remaining units to be vocoded.
MIN_UNITS = 3

# One entry per utterance that produced audio; the five lists stay aligned.
tgt_wavs = []
# Earlier name of the same list. The original cell created `cvsst_wavs` but
# appended to the undefined `tgt_wavs`, which raised NameError.
cvsst_wavs = tgt_wavs
s2ut_wavs = []
src_wavs = []
id_list = []
scs2ut_wavs_list = []

with torch.no_grad():
    for batch in tqdm.tqdm(test_set, desc="inference"):
        s2ut_brain.step += 1

        batch = batch.to(s2ut_brain.device)
        wavs, wav_lens = batch.src_sig
        tgt_wav, _ = batch.tgt_sig
        src_mel, _ = batch.src_mel

        # Reference encoding of the source speech, computed on the CPU where
        # the speaker adapter lives.
        src_mel[src_mel == s2ut_brain.hparams.pad_index] = 0
        src_enc_out = s2ut_brain.speaker_adapter[0](src_mel.cpu())

        # Use default padding value for wav2vec2
        wavs[wavs == s2ut_brain.hparams.pad_index] = 0.0
        # compute features, then dimensionality reduction
        enc_out = s2ut_brain.modules.wav2vec2(wavs, wav_lens)
        enc_out = s2ut_brain.modules.enc(enc_out)

        # The teacher-forced decoder pass of the original cell was dropped:
        # its result was never used. The stage here is always TEST, so the
        # test searcher is used directly.
        hyps, _, _, _ = s2ut_brain.hparams.test_search(
            enc_out.detach(), wav_lens
        )

        # Reset per batch so that a too-short hypothesis can never reuse the
        # audio generated for the previous utterance.
        s2ut_wav = None
        scs2ut_wav = None
        for hyp in hyps:
            units = [unit for unit in hyp if unit not in SPECIAL_UNIT_IDS]
            if len(units) <= MIN_UNITS:
                continue
            code = torch.LongTensor(units)

            # generate s2ut waveform
            s2ut_wav = s2ut_brain.test_vocoder.decode_unit(code)

            # generate scs2ut waveform
            expanded = s2ut_brain.var_predictor(code)
            fft_out = s2ut_brain.speaker_adapter[1](
                expanded,
                pad_idx=s2ut_brain.hparams.pad_index,
            )
            pred_mel = s2ut_brain.speaker_adapter[2](src_enc_out, fft_out)
            pred_mel = pred_mel.transpose(1, 2)
            scs2ut_wav = s2ut_brain.mel_vocoder.decode_batch(
                pred_mel
            ).squeeze(1)

        ids = batch.id
        if s2ut_wav is None or scs2ut_wav is None:
            logger.warning(
                "Skipping %s: hypothesis too short to vocode", ids[0]
            )
            continue

        id_list.append(ids[0])
        # Stored on the CPU so the clips can be played back later even when
        # the brain runs on an accelerator.
        tgt_wavs.append(tgt_wav.cpu())
        src_wavs.append(wavs.cpu())
        s2ut_wavs.append(s2ut_wav)
        scs2ut_wavs_list.append(scs2ut_wav)
        
        

In [None]:
# Play back, for every processed utterance, the source audio, the reference
# target audio and the two generated waveforms, each preceded by its label.
labelled_clips = (
    ("src_audio: ", src_wavs),
    ("cvsst_audio: ", tgt_wavs),
    ("s2ut_audio: ", s2ut_wavs),
    ("scs2ut_audio: ", scs2ut_wavs_list),
)
for utt_idx in range(len(tgt_wavs)):
    print("----")
    for label, clips in labelled_clips:
        print(label)
        display(Audio(data=clips[utt_idx], rate=16000))

    