<a href="https://colab.research.google.com/github/Ranamoeed/CodeBankVC/blob/main/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys
import time
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_decode
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian

def stream(string, variables):
    """Write a %-formatted status message to stdout on the current line.

    Args:
        string: %-style format template.
        variables: Value (or tuple of values) interpolated into ``string``.
    """
    # The leading carriage return rewinds the cursor so that successive calls
    # overwrite the same terminal line instead of scrolling.
    template = "\r{}".format(string)
    sys.stdout.write(template % variables)

class ResBlock(nn.Module):
    """Residual block built from two bias-free kernel-size-1 convolutions.

    Args:
        dims: Number of input channels; the output has the same number.
    """

    def __init__(self, dims):
        super().__init__()
        pointwise = dict(kernel_size=1, bias=False)
        self.conv1 = nn.Conv1d(dims, dims, **pointwise)
        self.conv2 = nn.Conv1d(dims, dims, **pointwise)
        self.batch_norm1 = nn.BatchNorm1d(dims)
        self.batch_norm2 = nn.BatchNorm1d(dims)

    def forward(self, x):
        """Return ``x`` plus conv -> batch norm -> ReLU -> conv -> batch norm of ``x``."""
        hidden = F.relu(self.batch_norm1(self.conv1(x)))
        hidden = self.batch_norm2(self.conv2(hidden))
        return hidden + x

class MelResNet(nn.Module):
    """Stack of residual blocks between an input and an output convolution.

    The input convolution uses a kernel of ``2 * pad + 1`` with no padding, so
    the last axis of the output is ``2 * pad`` frames shorter than the input.

    Args:
        num_res_blocks: Number of ``ResBlock`` layers.
        in_dims: Channels of the input features.
        compute_dims: Channels used inside the residual stack.
        res_out_dims: Channels produced by the output convolution.
        pad: Half-width of the input convolution kernel.
    """

    def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
        super().__init__()
        self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=2 * pad + 1, bias=False)
        self.batch_norm = nn.BatchNorm1d(compute_dims)
        self.layers = nn.ModuleList([ResBlock(compute_dims) for _ in range(num_res_blocks)])
        self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)

    def forward(self, x):
        """Run the input conv, every residual block, then the output conv."""
        x = F.relu(self.batch_norm(self.conv_in(x)))
        for block in self.layers:
            x = block(x)
        return self.conv_out(x)

class Stretch2d(nn.Module):
    """Upsample a 4-D tensor by repeating each element along its last two axes.

    Args:
        x_scale: Repetition factor for the last axis.
        y_scale: Repetition factor for the second-to-last axis.
    """

    def __init__(self, x_scale, y_scale):
        super().__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale

    def forward(self, x):
        """Return ``x`` of shape (b, c, h, w) stretched to (b, c, h * y_scale, w * x_scale)."""
        batch, channels, height, width = x.size()
        # Insert a singleton axis after both h and w, tile those axes, then
        # fold them back so every element is repeated contiguously.
        tiled = x[:, :, :, None, :, None].repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
        return tiled.view(batch, channels, height * self.y_scale, width * self.x_scale)

class UpsampleNetwork(nn.Module):
    """Upsample conditioning features with repeat-and-smooth layers.

    Each scale in ``upsample_scales`` adds a ``Stretch2d`` repeat followed by a
    1-channel ``Conv2d`` whose weights are initialised to a moving average.
    When ``use_aux_net`` is set, a ``MelResNet`` additionally produces
    auxiliary features that are stretched by the total scale.

    Args:
        feat_dims: Channels of the input features.
        upsample_scales: Per-layer upsampling factors.
        compute_dims: Hidden channels of the auxiliary ``MelResNet``.
        num_res_blocks: Residual blocks in the auxiliary ``MelResNet``.
        res_out_dims: Output channels of the auxiliary ``MelResNet``.
        pad: Number of context frames on each side of the input; the
            corresponding ``pad * total_scale`` samples are trimmed from both
            ends of the upsampled features.
        use_aux_net: Whether to build and run the auxiliary network.
    """

    def __init__(
        self,
        feat_dims,
        upsample_scales,
        compute_dims,
        num_res_blocks,
        res_out_dims,
        pad,
        use_aux_net,
    ):
        super().__init__()
        # ``np.cumproduct`` was removed in NumPy 2.0; ``np.prod`` yields the
        # same total and the int() cast keeps the attribute a plain Python int.
        self.total_scale = int(np.prod(upsample_scales))
        self.indent = pad * self.total_scale
        self.use_aux_net = use_aux_net
        if use_aux_net:
            self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)
            self.resnet_stretch = Stretch2d(self.total_scale, 1)
        self.up_layers = nn.ModuleList()
        for scale in upsample_scales:
            k_size = (1, scale * 2 + 1)
            padding = (0, scale)
            stretch = Stretch2d(scale, 1)
            conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
            # Start from a uniform averaging kernel so the layer initially
            # just smooths the repeated values.
            conv.weight.data.fill_(1.0 / k_size[1])
            self.up_layers.append(stretch)
            self.up_layers.append(conv)

    def forward(self, m):
        """Upsample ``m`` and optionally compute auxiliary features.

        Args:
            m: Feature tensor; a channel axis is inserted at dim 1 before the
                2-D layers, so a 3-D (batch, feat, frames) input is expected.

        Returns:
            Tuple ``(features, aux)`` with both tensors transposed so the
            upsampled time axis is dim 1; ``aux`` is ``None`` when
            ``use_aux_net`` is false.
        """
        if self.use_aux_net:
            aux = self.resnet(m).unsqueeze(1)
            aux = self.resnet_stretch(aux)
            aux = aux.squeeze(1)
            aux = aux.transpose(1, 2)
        else:
            aux = None
        m = m.unsqueeze(1)
        for f in self.up_layers:
            m = f(m)
        m = m.squeeze(1)
        # Use an explicit end index: with pad == 0 the original ``-self.indent``
        # became ``-0`` and silently produced an empty tensor.
        m = m[:, :, self.indent : m.size(-1) - self.indent]
        return m.transpose(1, 2), aux

class Upsample(nn.Module):
    """Upsample conditioning features by linear interpolation.

    Args:
        scale: Interpolation factor applied to the time axis.
        pad: Number of context frames on each side of the input;
            ``pad * scale`` samples are trimmed from both ends after
            interpolation.
        num_res_blocks: Residual blocks in the auxiliary ``MelResNet``.
        feat_dims: Channels of the input features.
        compute_dims: Hidden channels of the auxiliary ``MelResNet``.
        res_out_dims: Output channels of the auxiliary ``MelResNet``.
        use_aux_net: Whether ``forward`` computes auxiliary features.
    """

    def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net):
        super().__init__()
        self.scale = scale
        self.pad = pad
        self.indent = pad * scale
        self.use_aux_net = use_aux_net
        # Built unconditionally (as before) so the parameter set of the module
        # does not depend on ``use_aux_net``.
        self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad)

    def forward(self, m):
        """Interpolate ``m`` and optionally compute auxiliary features.

        Returns:
            Tuple ``(features, aux)`` with the time axis moved to dim 1;
            ``aux`` is ``None`` when ``use_aux_net`` is false.
        """
        if self.use_aux_net:
            aux = self.resnet(m)
            aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True)
            aux = aux.transpose(1, 2)
        else:
            aux = None
        m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True)
        # Use an explicit end index: with pad == 0 the original ``-self.indent``
        # became ``-0`` and silently produced an empty tensor.
        m = m[:, :, self.indent : m.size(-1) - self.indent]
        # Fixed rescaling of the interpolated features (constant kept from the
        # original implementation).
        m = m * 0.045
        return m.transpose(1, 2), aux

@dataclass
class WavernnArgs(Coqpit):
    """Architecture arguments read by ``Wavernn`` through ``self.args``."""

    # Hidden size of both GRU layers and output size of the input projection.
    rnn_dims: int = 512
    # Width of the two hidden fully connected layers.
    fc_dims: int = 512
    # Hidden channels of the MelResNet inside the upsampling module.
    compute_dims: int = 128
    # Output channels of the MelResNet; Wavernn splits them into four slices.
    res_out_dims: int = 128
    # Number of residual blocks in the MelResNet.
    num_res_blocks: int = 10
    # Whether auxiliary features are computed and fed to the RNN/FC layers.
    use_aux_net: bool = True
    # True selects UpsampleNetwork, False selects the interpolation Upsample.
    use_upsample_net: bool = True
    # Per-layer factors; Wavernn asserts their product equals audio.hop_length.
    upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
    # Output type: an int bit depth (2**mode classes), "mold" or "gauss".
    mode: str = "mold"
    # When True and mode is an int, inference applies mulaw_decode to the output.
    mulaw: bool = True
    # Context frames padded on each side of the features before upsampling.
    pad: int = 2
    # Number of channels of the input features.
    feat_dims: int = 80

class Wavernn(BaseVocoder):
    """WaveRNN vocoder predicting waveform samples from conditioning features.

    The features are upsampled by ``hop_length`` and combined with the previous
    sample, then passed through an input projection, two GRU layers with
    residual connections and three fully connected layers. The size of the
    last layer depends on ``args.mode``: ``2**mode`` for an integer mode,
    30 for ``"mold"`` and 2 for ``"gauss"``.

    Args:
        config: Model configuration. ``config.audio`` is used to build the
            ``AudioProcessor``. ``self.args`` and ``self.config`` are read
            throughout the class and are presumably set by
            ``BaseVocoder.__init__`` (not visible in this file).
    """

    def __init__(self, config: Coqpit):
        super().__init__(config)

        # The output layer size follows the chosen output distribution.
        if isinstance(self.args.mode, int):
            self.n_classes = 2**self.args.mode
        elif self.args.mode == "mold":
            self.n_classes = 3 * 10
        elif self.args.mode == "gauss":
            self.n_classes = 2
        else:
            raise RuntimeError("Unknown model mode value - ", self.args.mode)

        self.ap = AudioProcessor(**config.audio.to_dict())
        # The auxiliary features are consumed as four equal slices.
        self.aux_dims = self.args.res_out_dims // 4

        if self.args.use_upsample_net:
            # ``np.cumproduct`` was removed in NumPy 2.0; ``np.prod`` gives the
            # same total upsampling factor.
            assert (
                int(np.prod(self.args.upsample_factors)) == config.audio.hop_length
            ), " [!] upsample scales needs to be equal to hop_length"
            self.upsample = UpsampleNetwork(
                self.args.feat_dims,
                self.args.upsample_factors,
                self.args.compute_dims,
                self.args.num_res_blocks,
                self.args.res_out_dims,
                self.args.pad,
                self.args.use_aux_net,
            )
        else:
            self.upsample = Upsample(
                config.audio.hop_length,
                self.args.pad,
                self.args.num_res_blocks,
                self.args.feat_dims,
                self.args.compute_dims,
                self.args.res_out_dims,
                self.args.use_aux_net,
            )
        if self.args.use_aux_net:
            # "+ 1" accounts for the previous waveform sample.
            self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims)
            self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
            self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True)
            self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims)
            self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims)
            self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)
        else:
            self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims)
            self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
            self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True)
            self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims)
            self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims)
            self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes)

    def forward(self, x, mels):
        """Compute output-layer activations for a whole sequence.

        Args:
            x: Input waveform samples, indexed as (batch, time); a feature
                axis is appended before concatenation with the features.
            mels: Conditioning features passed to ``self.upsample``.

        Returns:
            Output of ``fc3`` with ``n_classes`` values in the last axis.
        """
        bsize = x.size(0)
        h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
        h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device)
        mels, aux = self.upsample(mels)

        if self.args.use_aux_net:
            # Split the auxiliary features into four consecutive slices, one
            # for each of the input, rnn2, fc1 and fc2 stages.
            aux_idx = [self.aux_dims * i for i in range(5)]
            a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
            a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
            a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
            a4 = aux[:, :, aux_idx[3] : aux_idx[4]]

        x = (
            torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
            if self.args.use_aux_net
            else torch.cat([x.unsqueeze(-1), mels], dim=2)
        )
        x = self.I(x)
        res = x
        self.rnn1.flatten_parameters()
        x, _ = self.rnn1(x, h1)

        x = x + res
        res = x
        x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x
        self.rnn2.flatten_parameters()
        x, _ = self.rnn2(x, h2)

        x = x + res
        x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x
        x = F.relu(self.fc1(x))

        x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def inference(self, mels, batched=None, target=None, overlap=None):
        """Generate a waveform sample by sample from conditioning features.

        Args:
            mels: Features as a NumPy array or tensor; a 2-D input gets a
                batch axis prepended.
            batched: When truthy, the upsampled features are folded into
                overlapping segments that are generated in parallel and
                cross-faded back together.
            target: Segment length used by the folding (batched mode).
            overlap: Overlap between segments (batched mode).

        Returns:
            The generated waveform, truncated to
            ``(frames - 1) * hop_length`` samples, with a linear fade-out over
            the last ``20 * hop_length`` samples when the output is longer
            than the fade.
        """
        # Remember the caller's mode so it can be restored afterwards instead
        # of unconditionally switching the model to training mode.
        was_training = self.training
        self.eval()
        output = []
        start = time.time()
        # Single-step cells sharing the trained GRU weights.
        rnn1 = self.get_gru_cell(self.rnn1)
        rnn2 = self.get_gru_cell(self.rnn2)

        with torch.no_grad():
            if isinstance(mels, np.ndarray):
                mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))

            if mels.ndim == 2:
                mels = mels.unsqueeze(0)
            wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length

            # Add ``pad`` zero frames on both sides; the upsampler trims the
            # corresponding samples again.
            mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both")
            mels, aux = self.upsample(mels.transpose(1, 2))

            if batched:
                mels = self.fold_with_overlap(mels, target, overlap)
                if aux is not None:
                    aux = self.fold_with_overlap(aux, target, overlap)

            b_size, seq_len, _ = mels.size()

            h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
            h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels)
            # Previous sample fed back into the network; starts at zero.
            x = torch.zeros(b_size, 1).type_as(mels)

            if self.args.use_aux_net:
                d = self.aux_dims
                aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]

            for i in range(seq_len):
                m_t = mels[:, i, :]

                if self.args.use_aux_net:
                    a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)

                x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1)
                x = self.I(x)
                h1 = rnn1(x, h1)

                x = x + h1
                inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x
                h2 = rnn2(inp, h2)

                x = x + h2
                x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x
                x = F.relu(self.fc1(x))

                x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x
                x = F.relu(self.fc2(x))

                logits = self.fc3(x)

                # Draw the next sample according to the output distribution
                # and feed it back as the next input.
                if self.args.mode == "mold":
                    sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
                    output.append(sample.view(-1))
                    x = sample.transpose(0, 1).type_as(mels)
                elif self.args.mode == "gauss":
                    sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
                    output.append(sample.view(-1))
                    x = sample.transpose(0, 1).type_as(mels)
                elif isinstance(self.args.mode, int):
                    posterior = F.softmax(logits, dim=1)
                    distrib = torch.distributions.Categorical(posterior)

                    # Map the class index from [0, n_classes - 1] to [-1, 1].
                    sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
                    output.append(sample)
                    x = sample.unsqueeze(-1)
                else:
                    raise RuntimeError("Unknown model mode value - ", self.args.mode)

                if i % 100 == 0:
                    self.gen_display(i, seq_len, b_size, start)

        output = torch.stack(output).transpose(0, 1)
        output = output.cpu()
        if batched:
            output = output.numpy()
            output = output.astype(np.float64)

            output = self.xfade_and_unfold(output, target, overlap)
        else:
            output = output[0]

        if self.args.mulaw and isinstance(self.args.mode, int):
            output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)

        # Linear fade-out over the last 20 hops of the signal.
        # NOTE(review): in the non-batched path ``output`` is still a
        # torch.Tensor here while ``fade_out`` is a NumPy array -- confirm the
        # in-place multiply below behaves as intended.
        fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
        output = output[:wave_len]

        if wave_len > len(fade_out):
            output[-20 * self.config.audio.hop_length :] *= fade_out

        self.train(was_training)
        return output

    def gen_display(self, i, seq_len, b_size, start):
        """Print generation progress, rate in kHz and the real-time factor."""
        gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
        realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate
        stream(
            "%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f  ",
            (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
        )

    def fold_with_overlap(self, x, target, overlap):
        """Cut a sequence into segments of ``target + 2 * overlap`` steps.

        Consecutive segments start ``target + overlap`` steps apart, so they
        share ``overlap`` steps. The input is zero-padded at the end when it
        does not divide evenly.

        Args:
            x: Tensor indexed as (batch, time, features); each segment is
                copied from ``x[:, start:end, :]`` into one row of the result.
            target: Number of non-overlapping steps per segment.
            overlap: Number of steps shared between neighbouring segments.

        Returns:
            Tensor of shape (num_folds, target + 2 * overlap, features).
        """
        _, total_len, features = x.size()

        # Number of complete folds that fit in the unpadded sequence.
        num_folds = (total_len - overlap) // (target + overlap)
        extended_len = num_folds * (overlap + target) + overlap
        remaining = total_len - extended_len

        # Pad the tail so the leftover steps form one more complete fold.
        if remaining != 0:
            num_folds += 1
            padding = target + 2 * overlap - remaining
            x = self.pad_tensor(x, padding, side="after")

        folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)

        for i in range(num_folds):
            start = i * (target + overlap)
            end = start + target + 2 * overlap
            folded[i] = x[:, start:end, :]

        return folded

    @staticmethod
    def get_gru_cell(gru):
        """Return an ``nn.GRUCell`` that shares the layer-0 weights of ``gru``."""
        gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
        gru_cell.weight_hh.data = gru.weight_hh_l0.data
        gru_cell.weight_ih.data = gru.weight_ih_l0.data
        gru_cell.bias_hh.data = gru.bias_hh_l0.data
        gru_cell.bias_ih.data = gru.bias_ih_l0.data
        return gru_cell

    @staticmethod
    def pad_tensor(x, pad, side="both"):
        """Zero-pad a (batch, time, channels) tensor along the time axis.

        Args:
            x: Tensor to pad.
            pad: Number of zero steps to add on each padded side.
            side: ``"both"``, ``"before"`` or ``"after"``.

        Returns:
            A new zero-initialised tensor on the device of ``x`` containing a
            copy of ``x`` at the position selected by ``side``.
        """
        b, t, c = x.size()
        total = t + 2 * pad if side == "both" else t + pad
        padded = torch.zeros(b, total, c).to(x.device)
        if side in ("before", "both"):
            padded[:, pad : pad + t, :] = x
        elif side == "after":
            padded[:, :t, :] = x
        return padded

    @staticmethod
    def xfade_and_unfold(y, target, overlap):
        """Cross-fade overlapping segments and add them into one signal.

        Args:
            y: Array of shape (num_folds, length); modified in place by the
                fades.
            target: Ignored; the segment target length is recomputed as
                ``length - 2 * overlap``.
            overlap: Number of samples shared between neighbouring segments.

        Returns:
            1-D float64 array of length
            ``num_folds * (target + overlap) + overlap``.
        """
        num_folds, length = y.shape
        target = length - 2 * overlap
        total_len = num_folds * (target + overlap) + overlap

        # The first half of each overlap is silence, the second half a fade.
        silence_len = overlap // 2
        fade_len = overlap - silence_len
        silence = np.zeros((silence_len), dtype=np.float64)

        # Square-root fades applied to the two ends of every segment.
        t = np.linspace(-1, 1, fade_len, dtype=np.float64)
        fade_in = np.sqrt(0.5 * (1 + t))
        fade_out = np.sqrt(0.5 * (1 - t))

        fade_in = np.concatenate([silence, fade_in])
        fade_out = np.concatenate([fade_out, silence])

        # With overlap == 0 there is nothing to fade, and ``y[:, -0:]`` would
        # select the whole row and fail to broadcast against an empty fade.
        if overlap > 0:
            y[:, :overlap] *= fade_in
            y[:, -overlap:] *= fade_out

        unfolded = np.zeros((total_len), dtype=np.float64)

        # Overlap-add every segment at its original offset.
        for i in range(num_folds):
            start = i * (target + overlap)
            end = start + target + 2 * overlap
            unfolded[start:end] += y[i]

        return unfolded

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):
        """Load model weights from ``checkpoint_path``.

        Args:
            config: Unused here; kept for interface compatibility.
            checkpoint_path: Path handed to ``load_fsspec``.
            eval: When True, switch the model to evaluation mode.
            cache: Forwarded to ``load_fsspec``.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training

    def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
        """Run one forward pass and compute the loss.

        Args:
            batch: Dict with ``"input"``, ``"waveform"`` and
                ``"waveform_coarse"`` entries (see ``format_batch``).
            criterion: Callable returning the loss dict for
                ``(y_hat, waveform_coarse)``.

        Returns:
            Tuple ``({"model_output": y_hat}, loss_dict)``.
        """
        mels = batch["input"]
        waveform = batch["waveform"]
        waveform_coarse = batch["waveform_coarse"]

        y_hat = self.forward(waveform, mels)
        if isinstance(self.args.mode, int):
            # Move the class axis to dim 1 and add a trailing axis for the loss.
            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        else:
            waveform_coarse = waveform_coarse.float()
        waveform_coarse = waveform_coarse.unsqueeze(-1)
        loss_dict = criterion(y_hat, waveform_coarse)
        return {"model_output": y_hat}, loss_dict

    def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]:
        """Evaluate a batch; identical to ``train_step``."""
        return self.train_step(batch, criterion)

    @torch.no_grad()
    def test(
        self, assets: Dict, test_loader: "DataLoader", output: Dict
    ) -> Tuple[Dict, Dict]:
        """Synthesize one test sample and build spectrogram figures and audio.

        Args:
            assets: Unused.
            test_loader: Loader whose dataset provides ``load_test_samples``.
            output: Unused.

        Returns:
            Tuple ``(figures, audios)`` keyed by ``test_{idx}/...``.
        """
        ap = self.ap
        figures = {}
        audios = {}
        samples = test_loader.dataset.load_test_samples(1)
        for idx, sample in enumerate(samples):
            x = torch.FloatTensor(sample[0])
            x = x.to(next(self.parameters()).device)
            y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
            # Re-analyse the generated audio so it can be plotted next to the
            # input features.
            x_hat = ap.melspectrogram(y_hat)
            figures.update(
                {
                    f"test_{idx}/ground_truth": plot_spectrogram(x.T),
                    f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
                }
            )
            audios.update({f"test_{idx}/audio": y_hat})
        return figures, audios

    def test_log(
        self, outputs: Dict, logger: "Logger", assets: Dict, steps: int
    ) -> None:
        """Send the figures and audios produced by ``test`` to ``logger``."""
        figures, audios = outputs
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    @staticmethod
    def format_batch(batch: Dict) -> Dict:
        """Convert a collated batch (waveform, mels, waveform_coarse) into a dict."""
        waveform = batch[0]
        mels = batch[1]
        waveform_coarse = batch[2]
        return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse}

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: List,
        verbose: bool,
        num_gpus: int,
    ):
        """Build a ``DataLoader`` over a ``WaveRNNDataset``.

        Evaluation uses a batch size of 1 and the eval worker count. A
        ``DistributedSampler`` is used when ``num_gpus > 1``; loader-level
        shuffling is enabled only when ``num_gpus == 0``.
        """
        ap = self.ap
        dataset = WaveRNNDataset(
            ap=ap,
            items=samples,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad=config.model_args.pad,
            mode=config.model_args.mode,
            mulaw=config.model_args.mulaw,
            is_training=not is_eval,
            verbose=verbose,
        )
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=1 if is_eval else config.batch_size,
            shuffle=num_gpus == 0,
            collate_fn=dataset.collate,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=True,
        )
        return loader

    def get_criterion(self):
        """Return the ``WaveRNNLoss`` matching the configured output mode."""
        return WaveRNNLoss(self.args.mode)

    @staticmethod
    def init_from_config(config: "WavernnConfig"):
        """Create a ``Wavernn`` model from ``config``."""
        return Wavernn(config)

In [None]:
from inspect import signature
from typing import Dict, List, Tuple

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.models import setup_discriminator, setup_generator
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results


class GAN(BaseVocoder):
    """Wrapper that trains a generator and a discriminator together.

    Throughout the class, index 0 refers to the discriminator and index 1 to
    the generator: ``train_step`` computes the discriminator loss for
    ``optimizer_idx == 0`` and the generator loss for ``optimizer_idx == 1``,
    and ``get_optimizer``, ``get_lr``, ``get_scheduler`` and ``get_criterion``
    return their items in that order.

    Args:
        config: Model configuration passed to ``setup_generator`` and
            ``setup_discriminator``.
        ap: Audio processor used for logging and data loading.
    """

    def __init__(self, config: Coqpit, ap: AudioProcessor = None):
        super().__init__(config)
        self.config = config
        self.model_g = setup_generator(config)
        self.model_d = setup_discriminator(config)
        self.train_disc = False  # updated by on_train_step_start / eval_step
        # Generator outputs cached by the discriminator step for the generator
        # step. Initialised here so the generator branch never hits a missing
        # attribute.
        self.y_hat_g = None
        self.y_hat_sub = None
        self.y_sub_g = None
        self.ap = ap

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the generator's ``forward`` on ``x``."""
        return self.model_g.forward(x)

    def inference(self, x: torch.Tensor) -> torch.Tensor:
        """Run the generator's ``inference`` on ``x``."""
        return self.model_g.inference(x)

    def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Compute outputs and losses for one optimizer.

        Args:
            batch: Dict with ``"input"`` and ``"waveform"``; additionally
                ``"input_disc"`` and ``"waveform_disc"`` when
                ``config.diff_samples_for_G_and_D`` is set.
            criterion: Sequence of loss callables indexed by ``optimizer_idx``.
            optimizer_idx: 0 for the discriminator step, 1 for the generator
                step.

        Returns:
            Tuple ``(outputs, loss_dict)``. Both are empty for the
            discriminator step while ``self.train_disc`` is false.

        Raises:
            ValueError: If ``optimizer_idx`` is neither 0 nor 1.
        """
        outputs = {}
        loss_dict = {}

        x = batch["input"]
        y = batch["waveform"]

        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if optimizer_idx == 0:
            # Discriminator step. The generator pass happens here and its
            # result is cached for the generator step that follows.
            y_hat = self.model_g(x)[:, :, : y.size(2)]

            self.y_hat_g = y_hat
            self.y_hat_sub = None
            self.y_sub_g = None

            # Multi-channel generator output is treated as sub-bands that are
            # merged with the generator's PQMF synthesis.
            if y_hat.shape[1] > 1:
                self.y_hat_sub = y_hat
                y_hat = self.model_g.pqmf_synthesis(y_hat)
                self.y_hat_g = y_hat
                self.y_sub_g = self.model_g.pqmf_analysis(y)

            scores_fake, feats_fake, feats_real = None, None, None

            if self.train_disc:
                if self.config.diff_samples_for_G_and_D:
                    # Use a separate batch for the discriminator.
                    x_d = batch["input_disc"]
                    y_d = batch["waveform_disc"]
                    with torch.no_grad():
                        y_hat = self.model_g(x_d)

                    if y_hat.shape[1] > 1:
                        y_hat = self.model_g.pqmf_synthesis(y_hat)
                else:
                    x_d = x.clone()
                    y_d = y.clone()
                    y_hat = self.y_hat_g

                # A discriminator whose forward takes two arguments is
                # conditioned on the input features.
                if len(signature(self.model_d.forward).parameters) == 2:
                    D_out_fake = self.model_d(y_hat.detach().clone(), x_d)
                    D_out_real = self.model_d(y_d, x_d)
                else:
                    D_out_fake = self.model_d(y_hat.detach())
                    D_out_real = self.model_d(y_d)

                # Discriminators may return (scores, features) or scores only.
                if isinstance(D_out_fake, tuple):
                    scores_fake, feats_fake = D_out_fake
                    if D_out_real is None:
                        scores_real, feats_real = None, None
                    else:
                        scores_real, feats_real = D_out_real
                else:
                    scores_fake = D_out_fake
                    scores_real = D_out_real

                loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
                outputs = {"model_outputs": y_hat}

        if optimizer_idx == 1:
            # Generator step, reusing the output cached by the discriminator step.
            scores_fake, feats_fake, feats_real = None, None, None
            if self.train_disc:
                if len(signature(self.model_d.forward).parameters) == 2:
                    D_out_fake = self.model_d(self.y_hat_g, x)
                else:
                    D_out_fake = self.model_d(self.y_hat_g)
                D_out_real = None

                if self.config.use_feat_match_loss:
                    # NOTE(review): this call never passes the conditioning
                    # input, unlike the calls above -- confirm it is valid for
                    # two-argument discriminators.
                    with torch.no_grad():
                        D_out_real = self.model_d(y)

                if isinstance(D_out_fake, tuple):
                    scores_fake, feats_fake = D_out_fake
                    if D_out_real is None:
                        feats_real = None
                    else:
                        _, feats_real = D_out_real
                else:
                    scores_fake = D_out_fake
                    feats_fake, feats_real = None, None

            loss_dict = criterion[optimizer_idx](
                self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
            )
            outputs = {"model_outputs": self.y_hat_g}
        return outputs, loss_dict

    def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
        """Build figures and one audio sample from the step outputs.

        The discriminator-step output (index 0) is used while the
        discriminator is training, otherwise the generator-step output.
        """
        y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
        y = batch["waveform"]
        figures = plot_results(y_hat, y, ap, name)
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios = {f"{name}/audio": sample_voice}
        return figures, audios

    def train_log(
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int
    ) -> None:
        """Log figures and audio for a training step.

        NOTE(review): this uses the "eval" name and the logger's ``eval_*``
        methods, exactly like ``eval_log`` -- confirm this is intended.
        """
        figures, audios = self._log("eval", self.ap, batch, outputs)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    @torch.no_grad()
    def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Run ``train_step`` without gradients, with the discriminator enabled."""
        self.train_disc = True
        return self.train_step(batch, criterion, optimizer_idx)

    def eval_log(
        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int
    ) -> None:
        """Log figures and audio for an evaluation step."""
        figures, audios = self._log("eval", self.ap, batch, outputs)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    def load_checkpoint(
        self,
        config: Coqpit,
        checkpoint_path: str,
        eval: bool = False,
        cache: bool = False,
    ) -> None:
        """Load weights from ``checkpoint_path``.

        A checkpoint containing a ``"model_disc"`` key is delegated to the
        generator's own ``load_checkpoint``. Otherwise ``state["model"]`` is
        loaded into this module; with ``eval`` set, the discriminator is then
        dropped and the generator's weight norm is removed if it supports it.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        if "model_disc" in state:
            self.model_g.load_checkpoint(config, checkpoint_path, eval)
        else:
            self.load_state_dict(state["model"])
            if eval:
                self.model_d = None
                if hasattr(self.model_g, "remove_weight_norm"):
                    self.model_g.remove_weight_norm()

    def on_train_step_start(self, trainer) -> None:
        """Enable discriminator training once enough steps have been done."""
        self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator

    def get_optimizer(self) -> List:
        """Return the optimizers as ``[discriminator, generator]``."""
        optimizer1 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
        )
        optimizer2 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
        )
        return [optimizer2, optimizer1]

    def get_lr(self) -> List:
        """Return the initial learning rates as ``[discriminator, generator]``."""
        return [self.config.lr_disc, self.config.lr_gen]

    def get_scheduler(self, optimizer) -> List:
        """Return the LR schedulers as ``[discriminator, generator]``.

        Args:
            optimizer: Optimizers in ``get_optimizer`` order, i.e.
                ``[discriminator, generator]``.
        """
        # Bind each schedule to the optimizer of the same network. The
        # original attached the generator schedule to optimizer[0] (the
        # discriminator) and vice versa.
        scheduler_disc = get_scheduler(
            self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0]
        )
        scheduler_gen = get_scheduler(
            self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1]
        )
        return [scheduler_disc, scheduler_gen]

    @staticmethod
    def format_batch(batch: List) -> Dict:
        """Convert a collated batch into the dict used by ``train_step``.

        A batch whose first item is a list holds separate (input, waveform)
        pairs for the generator and the discriminator.
        """
        if isinstance(batch[0], list):
            x_G, y_G = batch[0]
            x_D, y_D = batch[1]
            return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
        x, y = batch
        return {"input": x, "waveform": y}

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: List,
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ):
        """Build a ``DataLoader`` over a ``GANDataset``.

        Evaluation uses a batch size of 1, no segment cropping and the eval
        worker count. A ``DistributedSampler`` is used when ``num_gpus > 1``;
        loader-level shuffling is enabled only when ``num_gpus == 0``.
        ``rank`` is accepted but not used.
        """
        dataset = GANDataset(
            ap=self.ap,
            items=samples,
            seq_len=config.seq_len,
            hop_len=self.ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=not is_eval,
            return_segments=not is_eval,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=1 if is_eval else config.batch_size,
            shuffle=num_gpus == 0,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
        return loader

    def get_criterion(self):
        """Return the losses as ``[discriminator, generator]``."""
        return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]

    @staticmethod
    def init_from_config(config: Coqpit, verbose=True) -> "GAN":
        """Create a ``GAN`` with an ``AudioProcessor`` built from ``config``."""
        ap = AudioProcessor.init_from_config(config, verbose=verbose)
        return GAN(config, ap=ap)

In [None]:
import numpy as np
from torch import nn


class GBlock(nn.Module):
    """Conditioned residual block that downsamples and doubles the channels.

    Args:
        in_channels: Channels of the main input.
        cond_channels: Channels of the conditioning input.
        downsample_factor: Kernel size and stride of the average pooling.
    """

    def __init__(self, in_channels, cond_channels, downsample_factor):
        super().__init__()
        self.in_channels = in_channels
        self.cond_channels = cond_channels
        self.downsample_factor = downsample_factor

        doubled = in_channels * 2
        # Main branch, first half: pool, activate, widen.
        self.start = nn.Sequential(
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
            nn.ReLU(),
            nn.Conv1d(in_channels, doubled, kernel_size=3, padding=1),
        )
        # Projects the conditioning so it can be added to the main branch.
        self.lc_conv1d = nn.Conv1d(cond_channels, doubled, kernel_size=1)
        # Main branch, second half: dilated convolution after the fusion.
        self.end = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(doubled, doubled, kernel_size=3, dilation=2, padding=2),
        )
        # Shortcut branch: widen with a pointwise conv, then pool.
        self.residual = nn.Sequential(
            nn.Conv1d(in_channels, doubled, kernel_size=1),
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
        )

    def forward(self, inputs, conditions):
        """Fuse ``conditions`` into ``inputs`` and add the shortcut branch."""
        fused = self.start(inputs) + self.lc_conv1d(conditions)
        shortcut = self.residual(inputs)
        return self.end(fused) + shortcut


class DBlock(nn.Module):
    """Residual discriminator block with optional average-pool downsampling.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        downsample_factor: Kernel size and stride of the average pooling;
            pooling is applied only when this is greater than 1.
    """

    def __init__(self, in_channels, out_channels, downsample_factor):
        super().__init__()
        self.in_channels = in_channels
        self.downsample_factor = downsample_factor
        self.out_channels = out_channels

        self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor)
        main_branch = [
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2),
        ]
        self.layers = nn.Sequential(*main_branch)
        self.residual = nn.Sequential(nn.Conv1d(in_channels, out_channels, kernel_size=1))

    def forward(self, inputs):
        """Return the sum of the main and shortcut branches."""
        if self.downsample_factor <= 1:
            return self.layers(inputs) + self.residual(inputs)
        # The main branch pools before its convolutions, the shortcut pools
        # after its pointwise convolution.
        pool = self.donwsample_layer
        return self.layers(pool(inputs)) + pool(self.residual(inputs))


class ConditionalDiscriminator(nn.Module):
    """Discriminator that fuses conditioning features through a ``GBlock``.

    Args:
        in_channels: Channels the input is reshaped to in ``forward``.
        cond_channels: Channels of the conditioning input.
        downsample_factors: One factor per entry of ``out_channels`` plus a
            final factor used by the conditioning block.
        out_channels: Output channels of the ``DBlock`` layers that precede
            the conditioning block.
    """

    def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)):
        super().__init__()
        assert len(downsample_factors) == len(out_channels) + 1

        self.in_channels = in_channels
        self.cond_channels = cond_channels
        self.downsample_factors = downsample_factors
        self.out_channels = out_channels

        self.pre_cond_layers = nn.ModuleList()
        self.post_cond_layers = nn.ModuleList()

        # A 64-channel block without downsampling comes first, followed by one
        # block per entry of ``out_channels``. zip() stops before the last
        # downsample factor, which is reserved for the conditioning block.
        width = in_channels
        for next_width, factor in zip((64, *out_channels), (1, *downsample_factors)):
            self.pre_cond_layers.append(DBlock(width, next_width, factor))
            width = next_width

        self.cond_block = GBlock(width, cond_channels, downsample_factors[-1])

        # The conditioning block doubles the channel count.
        wide = width * 2
        self.post_cond_layers.extend(
            [
                DBlock(wide, wide, 1),
                DBlock(wide, wide, 1),
                nn.AdaptiveAvgPool1d(1),
                nn.Conv1d(wide, 1, kernel_size=1),
            ]
        )

    def forward(self, inputs, conditions):
        """Return the discriminator output for ``inputs`` given ``conditions``."""
        hidden = inputs.view(inputs.size()[0], self.in_channels, -1)
        for block in self.pre_cond_layers:
            hidden = block(hidden)
        hidden = self.cond_block(hidden, conditions)
        for block in self.post_cond_layers:
            hidden = block(hidden)
        return hidden


class UnconditionalDiscriminator(nn.Module):
    """Discriminator that scores an input without conditioning features.

    A ``DBlock`` maps the input to ``base_channels``; one ``DBlock`` per
    entry of ``downsample_factors`` then maps to the matching entry of
    ``out_channels``; two more ``DBlock`` layers, adaptive average pooling to
    length 1 and a 1x1 convolution to a single channel follow.

    Args:
        in_channels: channel count the input is reshaped to in ``forward``.
        base_channels: output width of the first ``DBlock``.
        downsample_factors: factor passed to each downsampling ``DBlock``.
        out_channels: output width of each downsampling ``DBlock``; must have
            at least as many entries as ``downsample_factors``.
    """

    def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)):
        super().__init__()
        self.in_channels = in_channels
        self.downsample_factors = downsample_factors
        self.out_channels = out_channels
        self.layers = nn.ModuleList()
        self.layers += [DBlock(self.in_channels, base_channels, 1)]
        in_channels = base_channels
        for i, factor in enumerate(downsample_factors):
            self.layers.append(DBlock(in_channels, out_channels[i], factor))
            # Track the width this layer actually produces. Doubling the
            # previous width only matches when out_channels happens to double
            # at every step (as the defaults 64 -> 128 -> 256 do).
            in_channels = out_channels[i]
        self.layers += [
            DBlock(in_channels, in_channels, 1),
            DBlock(in_channels, in_channels, 1),
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(in_channels, 1, kernel_size=1),
        ]

    def forward(self, inputs):
        """Return the output of the final 1x1 convolution.

        ``inputs`` is first reshaped to ``(batch, self.in_channels, -1)``.
        """
        batch_size = inputs.size()[0]
        outputs = inputs.view(batch_size, self.in_channels, -1)
        for layer in self.layers:
            outputs = layer(outputs)
        return outputs


class RandomWindowDiscriminator(nn.Module):
    """Ensemble of discriminators applied to randomly placed windows.

    For every entry of ``window_sizes`` one ``UnconditionalDiscriminator``
    and one ``ConditionalDiscriminator`` are created. In ``forward`` each of
    them scores a window of that size cut from a random position of the
    input.

    Args:
        cond_channels: passed to every ``ConditionalDiscriminator``.
        hop_length: number of samples in ``x`` per frame in ``c``.
        uncond_disc_donwsample_factors: downsample factors shared by all
            unconditional discriminators.
        cond_disc_downsample_factors: per-window downsample factors for the
            conditional discriminators; the product of each entry must equal
            ``hop_length // k`` for that window's ``k``.
        cond_disc_out_channels: per-window output widths for the conditional
            discriminators.
        window_sizes: window lengths in samples; each must be a multiple of
            ``hop_length``.
    """

    def __init__(
        self,
        cond_channels,
        hop_length,
        uncond_disc_donwsample_factors=(8, 4),
        cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)),
        cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)),
        window_sizes=(512, 1024, 2048, 4096, 8192),
    ):
        super().__init__()
        self.cond_channels = cond_channels
        self.window_sizes = window_sizes
        self.hop_length = hop_length
        self.base_window_size = self.hop_length * 2
        # k is the channel count each window is folded into by the
        # discriminators' ``view(batch, in_channels, -1)``.
        self.ks = [ws // self.base_window_size for ws in window_sizes]
        assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes)
        for ws in window_sizes:
            assert ws % hop_length == 0
        for idx, cf in enumerate(cond_disc_downsample_factors):
            assert np.prod(cf) == hop_length // self.ks[idx]
        self.unconditional_discriminators = nn.ModuleList([])
        for k in self.ks:
            layer = UnconditionalDiscriminator(
                in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors
            )
            self.unconditional_discriminators.append(layer)
        self.conditional_discriminators = nn.ModuleList([])
        for idx, k in enumerate(self.ks):
            layer = ConditionalDiscriminator(
                in_channels=k,
                cond_channels=cond_channels,
                downsample_factors=cond_disc_downsample_factors[idx],
                out_channels=cond_disc_out_channels[idx],
            )
            self.conditional_discriminators.append(layer)

    def forward(self, x, c):
        """Score random windows of ``x`` with and without conditioning ``c``.

        Args:
            x: tensor sliced along its last axis in samples; it is indexed
                with three indices, so it must be 3-D.
            c: conditioning tensor sliced along its last axis in frames,
                also 3-D.

        Returns:
            ``(scores, feats)``: ``scores`` holds the unconditional scores
            followed by the conditional scores, one per window size;
            ``feats`` is always an empty list.

        Raises:
            ValueError: from ``np.random.randint`` when the input is shorter
                than a window.
        """
        scores = []
        feats = []
        for window_size, layer in zip(self.window_sizes, self.unconditional_discriminators):
            # randint's upper bound is exclusive, so "+ 1" makes the last
            # valid start reachable and lets an input of exactly one window
            # through instead of raising.
            index = np.random.randint(x.shape[-1] - window_size + 1)
            score = layer(x[:, :, index : index + window_size])
            scores.append(score)
        for window_size, layer in zip(self.window_sizes, self.conditional_discriminators):
            frame_size = window_size // self.hop_length
            lc_index = np.random.randint(c.shape[-1] - frame_size + 1)
            # Keep the waveform window aligned with the conditioning frames.
            sample_index = lc_index * self.hop_length
            x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length]
            c_sub = c[:, :, lc_index : lc_index + frame_size]
            score = layer(x_sub, c_sub)
            scores.append(score)
        return scores, feats

In [None]:
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

class ResidualStack(nn.Module):
    def __init__(self, channels, num_res_blocks, kernel_size):
        super().__init__()

        assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
        base_padding = (kernel_size - 1) // 2

        self.blocks = nn.ModuleList()
        for idx in range(num_res_blocks):
            layer_kernel_size = kernel_size
            layer_dilation = layer_kernel_size**idx
            layer_padding = base_padding * layer_dilation
            self.blocks += [
                nn.Sequential(
                    nn.LeakyReLU(0.2),
                    nn.ReflectionPad1d(layer_padding),
                    weight_norm(
                        nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True)
                    ),
                    nn.LeakyReLU(0.2),
                    weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
                )
            ]

        self.shortcuts = nn.ModuleList(
            [weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
        )

    def forward(self, x):
        for block, shortcut in zip(self.blocks, self.shortcuts):
            x = shortcut(x) + block(x)
        return x

    def remove_weight_norm(self):
        for block, shortcut in zip(self.blocks, self.shortcuts):
            remove_parametrizations(block[2], "weight")
            remove_parametrizations(block[4], "weight")
            remove_parametrizations(shortcut, "weight")

import math
import torch
from torch import nn
from torch.nn import functional as F

class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and returns the input's dtype."""

    def forward(self, x):
        """Normalise ``x`` as float32, then cast back to ``x.dtype``."""
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)

def conv_nd(dims, *args, **kwargs):
    """Build a 1-D, 2-D or 3-D convolution layer.

    Args:
        dims: 1, 2 or 3, selecting ``nn.Conv1d``, ``nn.Conv2d`` or
            ``nn.Conv3d``.
        *args, **kwargs: forwarded unchanged to the selected constructor.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def normalization(channels):
    """Return a ``GroupNorm32`` with a group count that divides ``channels``.

    The starting group count is 8 for up to 16 channels, 16 for up to 64 and
    32 otherwise; it is halved until it divides ``channels``. The result must
    still be greater than 2, otherwise the assertion fails.
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32
    while channels % groups != 0:
        groups //= 2
    assert groups > 2
    return GroupNorm32(groups, channels)

def zero_module(module):
    """Zero every parameter of ``module`` in place and return the module."""
    for param in module.parameters():
        # detach() shares storage with the parameter, so zeroing the detached
        # view clears the parameter itself without involving autograd.
        detached = param.detach()
        detached.zero_()
    return module

class QKVAttention(nn.Module):
    """Multi-head dot-product attention over a packed query/key/value tensor."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, qk_bias=0):
        """Attend over the last axis of ``qkv``.

        Args:
            qkv: tensor of shape ``(batch, 3 * n_heads * ch, length)``.
            mask: optional mask; it is tiled ``n_heads`` times along its
                first axis and positions where it is false get a weight of
                ``-inf`` before the softmax.
            qk_bias: value added to the attention logits.

        Returns:
            Tensor of shape ``(batch, n_heads * ch, length)``.
        """
        batch, packed_width, seq_len = qkv.shape
        heads = self.n_heads
        assert packed_width % (3 * heads) == 0
        head_dim = packed_width // (3 * heads)

        per_head = qkv.reshape(batch * heads, head_dim * 3, seq_len)
        query, key, value = per_head.split(head_dim, dim=1)

        # Scaling q and k by ch ** -0.25 each divides the logits by sqrt(ch).
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        logits = torch.einsum("bct,bcs->bts", query * scale, key * scale)
        logits = logits + qk_bias
        if mask is not None:
            tiled_mask = mask.repeat(heads, 1, 1)
            logits[tiled_mask.logical_not()] = -torch.inf

        # Softmax runs in float32 and is cast back to the logits' dtype.
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = torch.einsum("bts,bcs->bct", probs, value)
        return attended.reshape(batch, -1, seq_len)

class AttentionBlock(nn.Module):
    """Self-attention block over the flattened trailing axes of its input.

    Args:
        channels: channel count of the input.
        num_heads: number of heads, used when ``num_head_channels`` is -1.
        num_head_channels: if not -1, the head count becomes
            ``channels // num_head_channels``.
        out_channels: output channel count; defaults to ``channels``.
        do_activation: apply SiLU after the normalisation.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        out_channels=None,
        do_activation=False,
    ):
        super().__init__()
        self.channels = channels
        if out_channels is None:
            out_channels = channels
        self.do_activation = do_activation
        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, out_channels * 3, 1)
        self.attention = QKVAttention(self.num_heads)
        # The skip path only needs a projection when the width changes.
        if out_channels == channels:
            self.x_proj = nn.Identity()
        else:
            self.x_proj = conv_nd(1, channels, out_channels, 1)
        # The output projection starts with all-zero parameters.
        self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))

    def forward(self, x, mask=None, qk_bias=0):
        """Apply attention and add the skip path.

        A 2-D ``mask`` is repeated across the batch, and a mask whose second
        axis does not match the last axis of ``x`` is cropped to it. The skip
        path is taken from the normalised (and optionally activated)
        activations, not from the raw input.
        """
        batch, chans, *spatial = x.shape
        last_axis = x.shape[-1]
        if mask is not None:
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0).repeat(batch, 1, 1)
            if mask.shape[1] != last_axis:
                mask = mask[:, :last_axis, :last_axis]

        normed = self.norm(x.reshape(batch, chans, -1))
        if self.do_activation:
            normed = F.silu(normed, inplace=True)

        attended = self.attention(self.qkv(normed), mask=mask, qk_bias=qk_bias)
        attended = self.proj_out(attended)
        skip = self.x_proj(normed)
        return (skip + attended).reshape(batch, skip.shape[1], *spatial)

class ConditioningEncoder(nn.Module):
    """1x1 convolution followed by a chain of ``AttentionBlock`` layers.

    Args:
        spec_dim: channel count of the input.
        embedding_dim: channel count after the initial convolution; also
            stored as ``self.dim``.
        attn_blocks: number of attention blocks in the chain.
        num_attn_heads: head count given to every attention block.
    """

    def __init__(
        self,
        spec_dim,
        embedding_dim,
        attn_blocks=6,
        num_attn_heads=4,
    ):
        super().__init__()
        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
        blocks = [AttentionBlock(embedding_dim, num_attn_heads) for _ in range(attn_blocks)]
        self.attn = nn.Sequential(*blocks)
        self.dim = embedding_dim

    def forward(self, x):
        """Project ``x`` and run it through the attention chain."""
        return self.attn(self.init(x))

In [None]:
import os
from dataclasses import dataclass

import librosa
import torch
import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit

from TTS.tts.layers.vcvc.nlp import NLP
from TTS.tts.layers.vcvc.hifigan_decoder import HifiDecoder
from TTS.tts.layers.vcvc.stream_generator import init_stream_support
from TTS.tts.layers.vcvc.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.vcvc.vcvc_manager import SpeakerManager, LanguageManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec

# Executed once, as a side effect of importing/running this module.
# NOTE(review): presumably installs the streaming-generation hooks defined in
# TTS.tts.layers.vcvc.stream_generator - confirm there.
init_stream_support()

def wav_to_mel_cloning(
    wav,
    mel_norms_file="...pth",
    mel_norms=None,
    device=torch.device("cpu"),
    n_fft=4096,
    hop_length=1024,
    win_length=4096,
    power=2,
    normalized=False,
    sample_rate=22050,
    f_min=0,
    f_max=8000,
    n_mels=80,
):
    """Convert a waveform to a normalised log-mel spectrogram.

    The mel spectrogram is clamped to a minimum of 1e-5, passed through a
    natural log, and divided per mel bin by ``mel_norms``.

    Args:
        wav: waveform tensor; it is moved to ``device``.
        mel_norms_file: path loaded with ``torch.load`` when ``mel_norms`` is
            ``None``.
        mel_norms: 1-D tensor with one divisor per mel bin.
        device: device the transform and the computation run on.
        n_fft, hop_length, win_length, power, normalized, sample_rate,
        f_min, f_max, n_mels: forwarded to
            ``torchaudio.transforms.MelSpectrogram`` (with ``norm="slaney"``).

    Returns:
        The normalised log-mel tensor, on ``device``.
    """
    mel_stft = torchaudio.transforms.MelSpectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        power=power,
        normalized=normalized,
        sample_rate=sample_rate,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        norm="slaney",
    ).to(device)
    wav = wav.to(device)
    mel = mel_stft(wav)
    mel = torch.log(torch.clamp(mel, min=1e-5))
    if mel_norms is None:
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted paths.
        mel_norms = torch.load(mel_norms_file, map_location=device)
    # Caller-supplied norms may live on a different device than ``device``;
    # align them so the division below cannot fail on a device mismatch.
    mel_norms = mel_norms.to(device)
    # Broadcast the per-bin divisors over the leading and trailing axes.
    mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
    return mel

def load_audio(audiopath, sampling_rate):
    """Load an audio file as a mono tensor at ``sampling_rate``, clipped to [-1, 1].

    Multi-channel audio is averaged over the channel axis, and the audio is
    resampled when the file's rate differs from ``sampling_rate``. A message
    is printed when any sample exceeds 10 or when no sample is negative.
    """
    waveform, native_sr = torchaudio.load(audiopath)

    # Down-mix to a single channel, keeping the channel axis.
    if waveform.size(0) != 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if native_sr != sampling_rate:
        waveform = torchaudio.functional.resample(waveform, native_sr, sampling_rate)

    # Sanity check only: the audio is still returned after the message.
    if torch.any(waveform > 10) or not torch.any(waveform < 0):
        print(f"Error with {audiopath}. Max={waveform.max()} min={waveform.min()}")

    return waveform.clip_(-1, 1)

def pad_or_truncate(t, length):
    """Make the last axis of ``t`` exactly ``length`` long.

    A tensor that already has the right length is returned as is, a shorter
    one is zero-padded on the right, and a longer one is cut to its first
    ``length`` entries.
    """
    current = t.shape[-1]
    if current == length:
        return t
    if current < length:
        return F.pad(t, (0, length - current))
    return t[..., :length]

@dataclass
class VcvcAudioConfig(Coqpit):
    # Audio sample rates in Hz.
    # NOTE(review): presumably the rate of audio fed into the model versus
    # the rate of the generated audio - confirm against the callers.
    sample_rate: int = 22050
    output_sample_rate: int = 24000

@dataclass
class VcvcArgs(Coqpit):
    # Model arguments read through ``self.args`` by the Vcvc model.
    # nlp_batch_size, nlp_checkpoint and decoder_checkpoint are copied onto
    # the model in Vcvc.__init__.
    # NOTE(review): enable_redaction, kv_cache, clvp_checkpoint, num_chars
    # and tokenizer_file are not read in the visible part of the file.
    nlp_batch_size: int = 1
    enable_redaction: bool = False
    kv_cache: bool = True
    nlp_checkpoint: str = None
    clvp_checkpoint: str = None
    decoder_checkpoint: str = None
    num_chars: int = 255
    tokenizer_file: str = ""
    # Settings forwarded to the NLP constructor in Vcvc.init_models.
    nlp_max_audio_tokens: int = 605
    nlp_max_text_tokens: int = 402
    nlp_max_prompt_tokens: int = 70
    nlp_layers: int = 30
    nlp_n_model_channels: int = 1024
    nlp_n_heads: int = 16
    # The three text-token fields default to None and are filled in from the
    # tokenizer by Vcvc.init_models when a tokenizer is loaded.
    nlp_number_text_tokens: int = None
    nlp_start_text_token: int = None
    nlp_stop_text_token: int = None
    nlp_num_audio_tokens: int = 8194
    nlp_start_audio_token: int = 8192
    nlp_stop_audio_token: int = 8193
    # Also passed to HifiDecoder as ar_mel_length_compression.
    nlp_code_stride_len: int = 1024
    # NOTE(review): not read in the visible part of the file.
    nlp_use_masking_gt_prompt_approach: bool = True
    # Selects the chunked conditioning path in Vcvc.get_nlp_cond_latents.
    nlp_use_perceiver_resampler: bool = False

    # Settings forwarded to the HifiDecoder constructor in Vcvc.init_models.
    input_sample_rate: int = 22050
    output_sample_rate: int = 24000
    output_hop_length: int = 256
    decoder_input_dim: int = 1024
    d_vector_dim: int = 512
    cond_d_vector_in_each_upsampling_layer: bool = True

    # NOTE(review): not read in the visible part of the file.
    duration_const: int = 102400

class Vcvc(BaseTTS):
    def __init__(self, config: Coqpit):
        """Create the tokenizer and sub-models described by ``config``.

        Args:
            config: model configuration; ``config.model_dir`` is read here.
                NOTE(review): ``self.args`` is read below but never assigned
                in this method - presumably set by ``BaseTTS.__init__`` from
                the config; confirm.
        """
        super().__init__(config, ap=None, tokenizer=None)
        self.mel_stats_path = None
        self.config = config
        self.nlp_checkpoint = self.args.nlp_checkpoint
        self.decoder_checkpoint = self.args.decoder_checkpoint
        self.models_dir = config.model_dir
        self.nlp_batch_size = self.args.nlp_batch_size

        self.tokenizer = VoiceBpeTokenizer()
        self.nlp = None
        # init_models() may replace self.nlp and always creates
        # self.hifigan_decoder.
        self.init_models()
        # One divisor per mel bin, used when computing conditioning mels;
        # all ones means no scaling.
        self.register_buffer("mel_stats", torch.ones(80))

    def init_models(self):
        """Instantiate the NLP model (when possible) and the HiFi-GAN decoder.

        When the tokenizer is loaded, the text-token fields of ``self.args``
        are filled in from it first. The NLP model is only built when
        ``nlp_number_text_tokens`` is set; the decoder is always built.
        """
        args = self.args
        bpe = self.tokenizer.tokenizer

        if bpe is not None:
            args.nlp_number_text_tokens = self.tokenizer.get_number_tokens()
            args.nlp_start_text_token = bpe.token_to_id("[START]")
            args.nlp_stop_text_token = bpe.token_to_id("[STOP]")

        if args.nlp_number_text_tokens:
            self.nlp = NLP(
                layers=args.nlp_layers,
                model_dim=args.nlp_n_model_channels,
                start_text_token=args.nlp_start_text_token,
                stop_text_token=args.nlp_stop_text_token,
                heads=args.nlp_n_heads,
                max_text_tokens=args.nlp_max_text_tokens,
                max_mel_tokens=args.nlp_max_audio_tokens,
                max_prompt_tokens=args.nlp_max_prompt_tokens,
                number_text_tokens=args.nlp_number_text_tokens,
                num_audio_tokens=args.nlp_num_audio_tokens,
                start_audio_token=args.nlp_start_audio_token,
                stop_audio_token=args.nlp_stop_audio_token,
                use_perceiver_resampler=args.nlp_use_perceiver_resampler,
                code_stride_len=args.nlp_code_stride_len,
            )

        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=args.input_sample_rate,
            output_sample_rate=args.output_sample_rate,
            output_hop_length=args.output_hop_length,
            ar_mel_length_compression=args.nlp_code_stride_len,
            decoder_input_dim=args.decoder_input_dim,
            d_vector_dim=args.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=args.cond_d_vector_in_each_upsampling_layer,
        )

    @property
    def device(self):
        """Device of the first parameter returned by ``self.parameters()``."""
        first_param = next(self.parameters())
        return first_param.device

    @torch.inference_mode()
    def get_nlp_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
        """Compute the conditioning latent of the NLP model for ``audio``.

        The audio is resampled to 22050 Hz when needed and, for a positive
        ``length``, cut to its first ``length`` seconds.

        Args:
            audio: waveform tensor indexed as ``[:, samples]``.
            sr: sample rate of ``audio``.
            length: maximum number of seconds used; values <= 0 keep it all.
            chunk_length: chunk size in seconds, used only when
                ``nlp_use_perceiver_resampler`` is enabled. In that mode the
                style embeddings of the chunks are averaged and chunks
                shorter than 0.33 s are skipped.

        Returns:
            The style embedding with axes 1 and 2 swapped.
        """
        target_sr = 22050

        def to_mel(wave, n_fft, hop_length, win_length):
            # Both paths share every mel setting except the STFT sizes.
            return wav_to_mel_cloning(
                wave,
                mel_norms=self.mel_stats.cpu(),
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                power=2,
                normalized=False,
                sample_rate=target_sr,
                f_min=0,
                f_max=8000,
                n_mels=80,
            )

        if sr != target_sr:
            audio = torchaudio.functional.resample(audio, sr, target_sr)
        if length > 0:
            audio = audio[:, : target_sr * length]

        if not self.args.nlp_use_perceiver_resampler:
            mel = to_mel(audio, 4096, 1024, 4096)
            cond_latent = self.nlp.get_style_emb(mel.to(self.device))
            return cond_latent.transpose(1, 2)

        chunk_samples = target_sr * chunk_length
        min_samples = target_sr * 0.33
        style_embs = []
        for start in range(0, audio.shape[1], chunk_samples):
            audio_chunk = audio[:, start : start + chunk_samples]
            if audio_chunk.size(-1) < min_samples:
                continue
            mel_chunk = to_mel(audio_chunk, 2048, 256, 1024)
            style_embs.append(self.nlp.get_style_emb(mel_chunk.to(self.device), None))

        cond_latent = torch.stack(style_embs).mean(dim=0)
        return cond_latent.transpose(1, 2)

    @torch.inference_mode()
    def get_speaker_embedding(self, audio, sr):
        """Return the speaker-encoder embedding of ``audio``.

        The audio is resampled from ``sr`` to 16 kHz, encoded with
        ``l2_norm=True`` by the decoder's speaker encoder, and given a
        trailing singleton axis.
        """
        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
        speaker_encoder = self.hifigan_decoder.speaker_encoder
        embedding = speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
        return embedding.unsqueeze(-1).to(self.device)

    @torch.inference_mode()
    def get_conditioning_latents(
        self,
        audio_path,
        max_ref_length=30,
        nlp_cond_len=6,
        nlp_cond_chunk_len=6,
        librosa_trim_db=None,
        sound_norm_refs=False,
        load_sr=22050,
    ):
        """Compute conditioning latents and a speaker embedding from reference audio.

        Args:
            audio_path: a single path or a list of paths to reference audio.
            max_ref_length: maximum number of seconds kept from each file.
            nlp_cond_len: ``length`` passed to ``get_nlp_cond_latents``.
            nlp_cond_chunk_len: ``chunk_length`` passed to
                ``get_nlp_cond_latents``.
            librosa_trim_db: when not ``None``, ``top_db`` threshold used to
                trim leading and trailing silence with librosa.
            sound_norm_refs: scale each reference to a peak of 0.75.
            load_sr: sample rate the files are loaded at.

        Returns:
            ``(nlp_cond_latents, speaker_embedding)``: the conditioning
            latents of the concatenated references and the mean of the
            per-file speaker embeddings.
        """
        if not isinstance(audio_path, list):
            audio_paths = [audio_path]
        else:
            audio_paths = audio_path

        speaker_embeddings = []
        audios = []
        speaker_embedding = None
        for file_path in audio_paths:
            audio = load_audio(file_path, load_sr)
            audio = audio[:, : load_sr * max_ref_length].to(self.device)
            if sound_norm_refs:
                peak = torch.abs(audio).max()
                # A silent reference has a zero peak; dividing by it would
                # turn the whole clip into NaN, so leave it untouched.
                if peak > 0:
                    audio = (audio / peak) * 0.75
            if librosa_trim_db is not None:
                # librosa works on NumPy arrays and returns one, so round-trip
                # through NumPy and bring the result back as a tensor on the
                # model device for the torch code below.
                trimmed, _ = librosa.effects.trim(audio.cpu().numpy(), top_db=librosa_trim_db)
                audio = torch.from_numpy(trimmed).to(self.device)

            speaker_embedding = self.get_speaker_embedding(audio, load_sr)
            speaker_embeddings.append(speaker_embedding)

            audios.append(audio)

        # The conditioning latents are computed on all references joined
        # along the sample axis.
        full_audio = torch.cat(audios, dim=-1)
        nlp_cond_latents = self.get_nlp_cond_latents(
            full_audio, load_sr, length=nlp_cond_len, chunk_length=nlp_cond_chunk_len
        )

        if speaker_embeddings:
            speaker_embedding = torch.stack(speaker_embeddings)
            speaker_embedding = speaker_embedding.mean(dim=0)

        return nlp_cond_latents, speaker_embedding

    def forward(self):
        """Not implemented for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def eval_step(self):
        """Not implemented for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError