In [1]:
import hashlib
import os
import tarfile
import wget
import yaml
from pathlib import Path
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


def resolve_diarization_cache_dir() -> Path:
    """
    Utility method to get the cache directory for the diarization module.

    Returns:
        Path: The path to the cache directory.
    """
    path = Path.joinpath(Path.home(), f".cache/torch/diarization")

    return path

In [15]:
# Inspired from NVIDIA NeMo's EncDecSpeakerLabelModel
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/label_models.py#L67
class EncDecSpeakerLabelModel:
    """The EncDecSpeakerLabelModel class encapsulates the encoder-decoder speaker label model."""

    def __init__(self, model_name: str = "titanet_large") -> None:
        """Initialize the EncDecSpeakerLabelModel class.

        The EncDecSpeakerLabelModel class encapsulates the encoder-decoder speaker label model.
        Only the "titanet_large" model is supported at the moment.
        For more models: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/label_models.py#L59

        Args:
            model_name (str, optional): The name of the model to use. Defaults to "titanet_large".

        Raises:
            ValueError: If the model name is not supported.
        """
        if model_name != "titanet_large":
            raise ValueError(
                f"Unknown model name: {model_name}. Only 'titanet_large' is supported at the moment."
            )

        self.model_name = model_name
        self.location_in_the_cloud = "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/titanet_large/versions/v1/files/titanet-l.nemo"
        self.cache_dir = Path.joinpath(resolve_diarization_cache_dir(), "titanet-l")
        cache_subfolder = hashlib.md5((self.location_in_the_cloud).encode("utf-8")).hexdigest()

        self.nemo_model_folder, self.nemo_model_file = self.download_model_if_required(
            url=self.location_in_the_cloud, cache_dir=self.cache_dir, subfolder=cache_subfolder,
        )

        self.model_files = Path.joinpath(self.nemo_model_folder, "model_files")
        if not self.model_files.exists():
            self.model_files.mkdir(parents=True, exist_ok=True)
            self.unpack_nemo_file(self.nemo_model_file, self.model_files)

        self.model_weights_file_path = Path.joinpath(self.model_files, "model_weights.ckpt")
        model_config_file_path = Path.joinpath(self.model_files, "model_config.yaml")
        with open(model_config_file_path, "r") as config_file:
            self.model_config = yaml.safe_load(config_file)


    @staticmethod
    def download_model_if_required(url, subfolder=None, cache_dir=None) -> Tuple[str, str]:
        """
        Helper function to download pre-trained weights from the cloud.

        Args:
            url: (str) URL to download from.
            cache_dir: (str) a cache directory where to download. If not present, this function will attempt to create it.
                If None (default), then it will be $HOME/.cache/torch/diarization
            subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can
                be empty

        Returns:
            Tuple[str, str]: cache_dir and filepath to the downloaded file.
        """
        destination = Path.joinpath(cache_dir, subfolder)

        if not destination.exists():
            destination.mkdir(parents=True, exist_ok=True)

        filename = url.split("/")[-1]
        destination_file = Path.joinpath(destination, filename)

        if destination_file.exists():
            return destination, destination_file

        i = 0
        while i < 10:  # try 10 times
            i += 1

            try:
                wget.download(url, str(destination_file))
                if os.path.exists(destination_file):
                    return destination, destination_file

            except:
                continue

        raise ValueError("Not able to download the diarization model, please try again later.")
    
    @staticmethod
    def unpack_nemo_file(filepath: Path, out_folder: Path) -> str:
        """
        Unpacks a .nemo file into a folder.

        Args:
            filepath (Path): path to the .nemo file (can be compressed or uncompressed)
            out_folder (Path): path to the folder where the .nemo file should be unpacked

        Returns:
            path to the unpacked folder
        """
        try:
            tar = tarfile.open(filepath, "r:")  # try uncompressed
        except tarfile.ReadError:
            tar = tarfile.open(filepath, "r:gz")  # try compressed
        finally:
            tar.extractall(path=out_folder)
            tar.close()


In [16]:
model = EncDecSpeakerLabelModel()

In [18]:
model.model_weights_file_path
ckpt = torch.load(model.model_weights_file_path, map_location="cpu")

In [23]:
print(ckpt.keys())

odict_keys(['preprocessor.featurizer.window', 'preprocessor.featurizer.fb', 'encoder.encoder.0.mconv.0.conv.weight', 'encoder.encoder.0.mconv.1.conv.weight', 'encoder.encoder.0.mconv.2.weight', 'encoder.encoder.0.mconv.2.bias', 'encoder.encoder.0.mconv.2.running_mean', 'encoder.encoder.0.mconv.2.running_var', 'encoder.encoder.0.mconv.2.num_batches_tracked', 'encoder.encoder.0.mconv.3.fc.0.weight', 'encoder.encoder.0.mconv.3.fc.2.weight', 'encoder.encoder.1.mconv.0.conv.weight', 'encoder.encoder.1.mconv.1.conv.weight', 'encoder.encoder.1.mconv.2.weight', 'encoder.encoder.1.mconv.2.bias', 'encoder.encoder.1.mconv.2.running_mean', 'encoder.encoder.1.mconv.2.running_var', 'encoder.encoder.1.mconv.2.num_batches_tracked', 'encoder.encoder.1.mconv.5.conv.weight', 'encoder.encoder.1.mconv.6.conv.weight', 'encoder.encoder.1.mconv.7.weight', 'encoder.encoder.1.mconv.7.bias', 'encoder.encoder.1.mconv.7.running_mean', 'encoder.encoder.1.mconv.7.running_var', 'encoder.encoder.1.mconv.7.num_batches_

In [4]:
encoder = model.model_config["encoder"]
decoder = model.model_config["decoder"]
preprocessor = model.model_config["preprocessor"]

In [10]:
for k, v in encoder.items():
    if isinstance(v, list):
        print(f"{k}:")
        for value in v:
            print(f"{value}")
    else:
        print(f"{k}: {v}")

_target_: nemo.collections.asr.modules.ConvASREncoder
feat_in: 80
activation: relu
conv_mask: True
jasper:
{'filters': 1024, 'repeat': 1, 'kernel': [3], 'stride': [1], 'dilation': [1], 'dropout': 0.0, 'residual': False, 'separable': True, 'se': True, 'se_context_size': -1}
{'filters': 1024, 'repeat': 3, 'kernel': [7], 'stride': [1], 'dilation': [1], 'dropout': 0.1, 'residual': True, 'separable': True, 'se': True, 'se_context_size': -1}
{'filters': 1024, 'repeat': 3, 'kernel': [11], 'stride': [1], 'dilation': [1], 'dropout': 0.1, 'residual': True, 'separable': True, 'se': True, 'se_context_size': -1}
{'filters': 1024, 'repeat': 3, 'kernel': [15], 'stride': [1], 'dilation': [1], 'dropout': 0.1, 'residual': True, 'separable': True, 'se': True, 'se_context_size': -1}
{'filters': 3072, 'repeat': 1, 'kernel': [1], 'stride': [1], 'dilation': [1], 'dropout': 0.0, 'residual': False, 'separable': True, 'se': True, 'se_context_size': -1}


In [11]:
for k, v in decoder.items():
    if isinstance(v, list):
        print(f"{k}:")
        for value in v:
            print(f"{value}")
    else:
        print(f"{k}: {v}")

_target_: nemo.collections.asr.modules.SpeakerDecoder
feat_in: 3072
num_classes: 16681
pool_mode: attention
emb_sizes: 192


In [13]:
for k, v in preprocessor.items():
    if isinstance(v, list):
        print(f"{k}:")
        for value in v:
            print(f"{value}")
    else:
        print(f"{k}: {v}")

_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
normalize: per_feature
window_size: 0.025
sample_rate: 16000
window_stride: 0.01
window: hann
features: 80
n_fft: 512
frame_splicing: 1
dither: 1e-05
