In [29]:
import hashlib
import math
import os
import random
import tarfile
from pathlib import Path
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import wget
import yaml


def resolve_diarization_cache_dir() -> Path:
    """
    Return the cache directory used by the diarization module.

    The directory is ``$HOME/.cache/torch/diarization``. It is only computed
    here, not created; callers create it when they first write into it.

    Returns:
        Path: The path to the cache directory.
    """
    return Path.home() / ".cache" / "torch" / "diarization"

In [31]:
# Adapted from the NVIDIA NeMo toolkit's `FilterbankFeatures` class (arguments follow `AudioToMelSpectrogramPreprocessor`):
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L469
class MelSpectrogramPreprocessor(nn.Module):
    """Mel spectrogram extraction (optionally log-compressed, normalized and padded).

    Window size and stride are given in seconds and converted to samples with
    ``sample_rate``.

    Args:
        sample_rate: Sample rate of the input audio in Hz.
        window_size: STFT window length in seconds.
        window_stride: STFT hop length in seconds.
        window: One of "hann", "hamming", "blackman", "bartlett" or "none".
        normalize: "per_feature", "all_features", or a falsy value to skip normalization.
        n_fft: FFT size; defaults to the next power of two >= the window length in samples.
        preemph: Pre-emphasis coefficient, or None to disable pre-emphasis.
        features: Number of mel filterbank channels.
        lowfreq: Lowest filterbank frequency in Hz.
        highfreq: Highest filterbank frequency in Hz (defaults to sample_rate / 2).
        log: Whether to take the log of the mel energies.
        log_zero_guard_type: "add" or "clamp" - how log(0) is avoided.
        log_zero_guard_value: A number, "tiny" or "eps" (see ``log_zero_guard_value_fn``).
        dither: Stored on the module; ``forward`` does not apply it.
        pad_to: Pad the time axis to a multiple of this int, to the length implied by
            ``max_duration`` when "max", or not at all when 0.
        max_duration: Maximum audio duration in seconds (used for pad_to="max").
        frame_splicing: Number of copies stacked along the feature axis (1 = none).
        exact_pad: Reflect-pad manually and run the STFT with ``center=False``.
        pad_value: Value written into frames beyond each sequence length.
        mag_power: Exponent applied to the STFT magnitude.
        use_grads: If False, ``forward`` is wrapped in ``torch.no_grad``.
        rng: ``random.Random`` used for narrowband augmentation (a fresh one if None).
        nb_augmentation_prob: Narrowband augmentation probability; stored only,
            ``forward`` does not apply it.
        nb_max_freq: Narrowband augmentation cut-off frequency in Hz.
    """

    # Added to the standard deviation during normalization to avoid division by zero.
    _NORMALIZE_EPS = 1e-5

    def __init__(
        self,
        sample_rate: int = 16000,
        window_size: float = 0.025,
        window_stride: float = 0.01,
        window: str = "hann",
        normalize: str = "per_feature",
        n_fft: Optional[int] = None,
        preemph: Optional[float] = 0.97,
        features: int = 64,
        lowfreq: int = 0,
        highfreq: Optional[int] = None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2 ** -24,
        dither=1e-5,
        pad_to=16,
        max_duration=16.7,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        use_grads=False,
        rng=None,
        nb_augmentation_prob=0.0,
        nb_max_freq=4000,
    ):
        super().__init__()
        # librosa is only needed here, to build the mel filterbank.
        import librosa

        # We want to avoid taking the log of zero.
        # There are two options: either adding or clamping to a small value.
        if log_zero_guard_type not in ["add", "clamp"]:
            raise ValueError(
                f"{self} received {log_zero_guard_type} for the "
                f"log_zero_guard_type parameter. It must be either 'add' or "
                f"'clamp'."
            )

        # Window and hop lengths in samples. The hop must be known before the
        # exact_pad parity check below.
        self.win_length = int(window_size * sample_rate)
        self.hop_length = int(window_stride * sample_rate)
        if exact_pad and self.hop_length % 2 == 1:
            raise NotImplementedError(
                f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
                "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
            )

        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
        self.exact_pad = exact_pad
        self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None

        torch_windows = {
            "hann": torch.hann_window,
            "hamming": torch.hamming_window,
            "blackman": torch.blackman_window,
            "bartlett": torch.bartlett_window,
            "none": None,
        }
        if window not in torch_windows:
            raise ValueError(f"{self} received unknown window {window!r}. It must be one of {list(torch_windows)}.")
        window_fn = torch_windows[window]
        window_tensor = window_fn(self.win_length, periodic=False) if window_fn is not None else None
        # May be None ("none"); `stft` then runs without a window.
        self.register_buffer("window", window_tensor)

        self.normalize = normalize
        self.log = log
        self.dither = dither
        self.frame_splicing = frame_splicing
        self.nfilt = features
        self.preemph = preemph
        self.pad_to = pad_to
        highfreq = highfreq or sample_rate / 2

        filterbanks = torch.tensor(
            librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=self.nfilt, fmin=lowfreq, fmax=highfreq),
            dtype=torch.float,
        ).unsqueeze(0)
        self.register_buffer("fb", filterbanks)

        # Calculate maximum sequence length (used when pad_to == "max").
        max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
        # pad_to may be the string "max", which cannot be compared with an int.
        max_pad = pad_to - (max_length % pad_to) if isinstance(pad_to, int) and pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power
        self.log_zero_guard_type = log_zero_guard_type
        self.log_zero_guard_value = log_zero_guard_value

        self.use_grads = use_grads
        if not use_grads:
            self.forward = torch.no_grad()(self.forward)
        self._rng = random.Random() if rng is None else rng
        self.nb_augmentation_prob = nb_augmentation_prob
        if self.nb_augmentation_prob > 0.0:
            if nb_max_freq >= sample_rate / 2:
                self.nb_augmentation_prob = 0.0
            else:
                # Use the resolved FFT size: the `n_fft` argument may be None.
                self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * self.n_fft)

    def stft(self, x):
        """Complex STFT of ``x`` using this module's window and hop settings."""
        window = self.window.to(dtype=torch.float) if self.window is not None else None
        return torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=not self.exact_pad,
            window=window,
            return_complex=True,
        )

    def log_zero_guard_value_fn(self, x):
        """Resolve the log-zero guard to a number ("tiny"/"eps" use ``x``'s dtype)."""
        if isinstance(self.log_zero_guard_value, str):
            if self.log_zero_guard_value == "tiny":
                return torch.finfo(x.dtype).tiny
            elif self.log_zero_guard_value == "eps":
                return torch.finfo(x.dtype).eps
            else:
                raise ValueError(
                    f"{self} received {self.log_zero_guard_value} for the "
                    f"log_zero_guard_value parameter. It must be either a "
                    f"number, 'tiny', or 'eps'"
                )
        else:
            return self.log_zero_guard_value

    def get_seq_len(self, seq_len):
        """Number of STFT frames produced for inputs of ``seq_len`` samples.

        With ``exact_pad`` the input is reflect-padded by ``stft_pad_amount`` on each
        side; otherwise ``torch.stft(center=True)`` pads ``n_fft // 2`` on each side.
        """
        pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
        seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1
        return seq_len.to(dtype=torch.long)

    @property
    def filter_banks(self):
        return self.fb

    @classmethod
    def _normalize_batch(cls, x, seq_len, normalize_type):
        """Normalize features using statistics from the valid (unpadded) frames only.

        Supports "per_feature" (per channel) and "all_features" (per utterance);
        any other ``normalize_type`` returns ``x`` unchanged.

        Args:
            x: Features indexed as (batch, features, time).
            seq_len: Number of valid frames per batch element.
            normalize_type: "per_feature" or "all_features".

        Returns:
            Tuple of (normalized x, mean, std); mean/std are None when unchanged.
        """
        if normalize_type == "per_feature":
            seq_len = seq_len.to(x.device)
            if torch.any(seq_len <= 1):
                raise ValueError(
                    "per_feature normalization needs at least 2 valid frames per sequence; "
                    "shorter inputs would produce a NaN standard deviation."
                )
            max_time = x.size(-1)
            # (batch, 1, time) mask of frames inside each sequence.
            valid = (torch.arange(max_time, device=x.device).unsqueeze(0) < seq_len.unsqueeze(1)).unsqueeze(1)
            count = valid.sum(dim=2).to(x.dtype)
            zeros = torch.zeros_like(x)
            mean = torch.where(valid, x, zeros).sum(dim=2) / count
            centered = torch.where(valid, x - mean.unsqueeze(2), zeros)
            # Unbiased (N - 1) estimate, like torch.std's default.
            std = torch.sqrt(centered.pow(2).sum(dim=2) / (count - 1.0)) + cls._NORMALIZE_EPS
            return (x - mean.unsqueeze(2)) / std.unsqueeze(2), mean, std
        if normalize_type == "all_features":
            lengths = [int(n) for n in seq_len]
            mean = torch.stack([x[i, :, :n].mean() for i, n in enumerate(lengths)])
            std = torch.stack([x[i, :, :n].std() for i, n in enumerate(lengths)]) + cls._NORMALIZE_EPS
            return (x - mean.view(-1, 1, 1)) / std.view(-1, 1, 1), mean, std
        return x, None, None

    @staticmethod
    def _splice_frames(x, frame_splicing):
        """Stack ``frame_splicing`` copies of ``x`` along the feature axis (dim 1).

        NOTE(review): this mirrors NeMo's ``splice_frames``, whose shifted slices
        concatenate back to ``x`` itself - confirm against upstream if needed.
        """
        return torch.cat([x] * frame_splicing, dim=1)

    def forward(self, x, seq_len, linear_spec=False):
        """Compute features for a batch of waveforms.

        Args:
            x: Waveforms indexed as (batch, samples).
            seq_len: Number of valid samples per batch element.
            linear_spec: If True, return the (power) spectrogram before the mel projection.

        Returns:
            Tuple of (features, number of valid frames per batch element).
        """
        seq_len = self.get_seq_len(seq_len.float())

        if self.stft_pad_amount is not None:
            x = torch.nn.functional.pad(
                x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
            ).squeeze(1)

        # Pre-emphasis: y[t] = x[t] - preemph * x[t - 1]; the first sample is kept.
        if self.preemph is not None:
            x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)

        # Disable autocast to get the full range of STFT values.
        with torch.cuda.amp.autocast(enabled=False):
            x = self.stft(x)

        # torch.stft returns a complex tensor; convert to magnitude.
        # The guard keeps sqrt differentiable at zero when gradients are required.
        guard = 0 if not self.use_grads else 1e-5
        x = torch.view_as_real(x)
        x = torch.sqrt(x.pow(2).sum(-1) + guard)

        # Get the power spectrum.
        if self.mag_power != 1.0:
            x = x.pow(self.mag_power)

        if linear_spec:
            return x, seq_len

        # Project onto the mel filterbank.
        x = torch.matmul(self.fb.to(x.dtype), x)

        if self.log:
            if self.log_zero_guard_type == "add":
                x = torch.log(x + self.log_zero_guard_value_fn(x))
            elif self.log_zero_guard_type == "clamp":
                x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
            else:
                raise ValueError("log_zero_guard_type was not understood")

        if self.frame_splicing > 1:
            x = self._splice_frames(x, self.frame_splicing)

        if self.normalize:
            x, _, _ = self._normalize_batch(x, seq_len, normalize_type=self.normalize)

        # Overwrite every frame beyond each sequence's length with pad_value.
        max_len = x.size(-1)
        mask = torch.arange(max_len, device=x.device).unsqueeze(0) >= seq_len.to(x.device).unsqueeze(1)
        x = x.masked_fill(mask.unsqueeze(1), self.pad_value)

        # Pad the time axis to a multiple of `pad_to` (for efficiency) or to max_length.
        pad_to = self.pad_to
        if pad_to == "max":
            x = nn.functional.pad(x, (0, int(self.max_length) - x.size(-1)), value=self.pad_value)
        elif pad_to > 0:
            pad_amt = x.size(-1) % pad_to
            if pad_amt != 0:
                x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)

        return x, seq_len

In [9]:
# Inspired by NVIDIA NeMo's EncDecSpeakerLabelModel
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/label_models.py#L67
class EncDecSpeakerLabelModel:
    """The EncDecSpeakerLabelModel class encapsulates the encoder-decoder speaker label model.

    The pre-trained `.nemo` archive is downloaded into the diarization cache, unpacked,
    and its ``model_config.yaml`` is used to build the preprocessor, encoder and decoder.
    """

    def __init__(self, model_name: str = "titanet_large") -> None:
        """Initialize the EncDecSpeakerLabelModel class.

        Only the "titanet_large" model is supported at the moment.
        For more models: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/label_models.py#L59

        Args:
            model_name (str, optional): The name of the model to use. Defaults to "titanet_large".

        Raises:
            ValueError: If the model name is not supported, or the model cannot be downloaded.
        """
        if model_name != "titanet_large":
            raise ValueError(
                f"Unknown model name: {model_name}. Only 'titanet_large' is supported at the moment."
            )

        self.model_name = model_name
        self.location_in_the_cloud = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/titanet_large/versions/v1/files/titanet-l.nemo"
        self.cache_dir = resolve_diarization_cache_dir() / "titanet-l"
        # The sub-folder is keyed on the URL so a different model URL gets its own folder.
        cache_subfolder = hashlib.md5(self.location_in_the_cloud.encode("utf-8")).hexdigest()

        self.nemo_model_folder, self.nemo_model_file = self.download_model_if_required(
            url=self.location_in_the_cloud, cache_dir=self.cache_dir, subfolder=cache_subfolder,
        )

        self.model_files = self.nemo_model_folder / "model_files"
        model_config_file_path = self.model_files / "model_config.yaml"
        # Unpack when the config is missing, not merely when the folder is missing,
        # so an interrupted earlier unpack is retried.
        if not model_config_file_path.exists():
            self.model_files.mkdir(parents=True, exist_ok=True)
            self.unpack_nemo_file(self.nemo_model_file, self.model_files)

        # NOTE(review): the weights are located but not loaded into the modules yet.
        self.model_weights_file_path = self.model_files / "model_weights.ckpt"
        with open(model_config_file_path, "r") as config_file:
            self.model_config = yaml.safe_load(config_file)

        self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(self.model_config["preprocessor"])
        self.encoder = EncDecSpeakerLabelModel.from_config_dict(self.model_config["encoder"])
        self.decoder = EncDecSpeakerLabelModel.from_config_dict(self.model_config["decoder"])

    @staticmethod
    def from_config_dict(config: dict):
        """Instantiate a module from one section of the NeMo model config.

        Delegates to NeMo's ``Serialization.from_config_dict``; the config sections
        carry a Hydra-style ``_target_`` class path.
        """
        # Imported lazily so the download/unpack helpers work without NeMo installed.
        from nemo.core.classes.common import Serialization
        from omegaconf import OmegaConf

        return Serialization.from_config_dict(OmegaConf.create(config))

    @staticmethod
    def download_model_if_required(
        url: str,
        subfolder: Optional[str] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        num_retries: int = 10,
    ) -> Tuple[Path, Path]:
        """
        Helper function to download pre-trained weights from the cloud.

        If the target file already exists it is returned without downloading.

        Args:
            url: (str) URL to download from.
            subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder.
                If None or empty, the file is stored directly in cache_dir.
            cache_dir: (str or Path) a cache directory where to download. It is created if missing.
                If None (default), then it will be $HOME/.cache/torch/diarization
            num_retries: (int) number of download attempts before giving up.

        Returns:
            Tuple[Path, Path]: the destination folder and the path to the downloaded file.

        Raises:
            ValueError: if every download attempt fails.
        """
        cache_dir = Path(cache_dir) if cache_dir is not None else resolve_diarization_cache_dir()
        destination = cache_dir / subfolder if subfolder else cache_dir
        destination.mkdir(parents=True, exist_ok=True)

        filename = url.split("/")[-1]
        destination_file = destination / filename

        if destination_file.exists():
            return destination, destination_file

        last_error = None
        for _ in range(num_retries):
            try:
                wget.download(url, str(destination_file))
            except Exception as error:  # network/IO failure: remember it and retry
                last_error = error
                continue
            if destination_file.exists():
                return destination, destination_file

        raise ValueError("Not able to download the diarization model, please try again later.") from last_error

    @staticmethod
    def unpack_nemo_file(filepath: Path, out_folder: Path) -> str:
        """
        Unpacks a .nemo file into a folder.

        Args:
            filepath (Path): path to the .nemo file (can be compressed or uncompressed)
            out_folder (Path): path to the folder where the .nemo file should be unpacked

        Returns:
            path to the unpacked folder

        Raises:
            tarfile.ReadError: if the file is not a (possibly compressed) tar archive.
        """
        # "r:*" lets tarfile detect whether the archive is compressed.
        with tarfile.open(filepath, "r:*") as tar:
            # The archive comes from the network: use the "data" extraction filter
            # (rejects paths escaping out_folder) when this Python version provides it.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=out_folder, filter="data")
            else:
                tar.extractall(path=out_folder)
        return str(out_folder)


In [14]:
from nemo.core.classes.common import Serialization
from omegaconf import DictConfig, OmegaConf

In [21]:
# Downloads the pre-trained TitaNet-L archive on first run (cached afterwards) and unpacks it.
model = EncDecSpeakerLabelModel(model_name="titanet_large")

In [24]:
# Preprocessor section of the restored model config.
preprocessor_config = model.model_config["preprocessor"]
preprocessor_config

{'_target_': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
 'normalize': 'per_feature',
 'window_size': 0.025,
 'sample_rate': 16000,
 'window_stride': 0.01,
 'window': 'hann',
 'features': 80,
 'n_fft': 512,
 'frame_splicing': 1,
 'dither': 1e-05}