<a href="https://colab.research.google.com/github/SupradeepDanturti/ConvAIProject/blob/main/interface_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture

!pip install speechbrain

In [24]:
%%file interface_hyperparams.yaml

# Hyperparameters loaded by the SpeakerCounter inference interface.

# Feature parameters
sample_rate: 16000  # read by SpeakerCounter.__init__ as self.hparams.sample_rate
n_mels: 80          # feature dimension produced by Fbank and consumed by the embedding model

# Model parameters
n_classes: 5        # number of classifier outputs
emb_dim: 192        # embedding width; the embedding model output and classifier input must agree

# Model components

# Maps predicted class indices back to text labels (used by classify_file).
label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder

compute_features: !new:speechbrain.lobes.features.Fbank
    n_mels: !ref <n_mels>

mean_var_norm: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence
    std_norm: False

embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN
    input_size: !ref <n_mels>
    channels: [256, 256, 256, 256, 768]
    kernel_sizes: [5, 3, 3, 3, 1]
    dilations: [1, 2, 3, 4, 1]
    attention_channels: 128
    lin_neurons: !ref <emb_dim>

classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
    input_size: !ref <emb_dim>
    out_neurons: !ref <n_classes>

# Modules exposed to the interface as self.mods.<name>; the names must match
# SpeakerCounter.MODULES_NEEDED.
modules:
    compute_features: !ref <compute_features>
    embedding_model: !ref <embedding_model>
    classifier: !ref <classifier>
    mean_var_norm: !ref <mean_var_norm>

# Directory holding the trained checkpoints and the label encoder file.
pretrained_path: /content

# Loads the trained parameters into the objects declared above.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        embedding_model: !ref <embedding_model>
        classifier: !ref <classifier>
        label_encoder: !ref <label_encoder>
    paths:
        embedding_model: !ref <pretrained_path>/embedding_model.ckpt
        classifier: !ref <pretrained_path>/classifier.ckpt
        label_encoder: !ref <pretrained_path>/label_encoder.txt

Overwriting interface_hyperparams.yaml


In [27]:
%%file SpeakerCounter.py

import torch
from speechbrain.inference.interfaces import Pretrained
import torchaudio
import math
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.fetching import fetch

class SpeakerCounter(Pretrained):
    """Inference interface that classifies audio with an embedding model and a classifier.

    The pipeline is: ``compute_features`` -> ``mean_var_norm`` ->
    ``embedding_model`` -> ``classifier``. All four modules must be declared
    in the ``modules`` section of the hyperparameter file.
    """

    # Modules that the hyperparameter file has to provide (accessed via self.mods).
    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model",
        "classifier",
    ]

    def __init__(self, *args, **kwargs):
        """Build the interface and cache the sample rate declared in the hparams."""
        super().__init__(*args, **kwargs)
        self.sample_rate = self.hparams.sample_rate

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Compute embeddings for a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms. A 1-D tensor is treated as a single waveform
            and gets a leading batch dimension.
        wav_lens : torch.Tensor, optional
            Relative length of each waveform. When omitted, every waveform is
            assumed to be full length (all ones).
        normalize : bool
            If True, the embeddings are passed through the ``mean_var_norm_emb``
            hyperparameter when the hparams file defines one. If it is not
            defined, a warning is issued and the embeddings are returned
            unnormalized (previously the flag was silently ignored).

        Returns
        -------
        torch.Tensor
            The output of ``embedding_model`` for the batch.
        """
        # Manage single waveforms in input
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Storing waveform in the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Computing features and embeddings
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)

        if normalize:
            emb_normalizer = getattr(self.hparams, "mean_var_norm_emb", None)
            if emb_normalizer is None:
                import warnings

                warnings.warn(
                    "normalize=True was requested but the hyperparameter file "
                    "does not define 'mean_var_norm_emb'; returning "
                    "unnormalized embeddings.",
                    stacklevel=2,
                )
            else:
                # Same call SpeechBrain's EncoderClassifier makes for normalize=True.
                embeddings = emb_normalizer(
                    embeddings,
                    torch.ones(embeddings.shape[0], device=self.device),
                )
        return embeddings

    def classify_batch(self, wavs, wav_lens=None):
        """Classify a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms (see ``encode_batch``).
        wav_lens : torch.Tensor, optional
            Relative lengths (see ``encode_batch``).

        Returns
        -------
        out_prob : torch.Tensor
            Classifier output for each item, after ``squeeze(1)``.
        score : torch.Tensor
            Maximum of ``out_prob`` along the last dimension.
        index : torch.Tensor
            Position of that maximum, i.e. the predicted class index.
        """
        emb = self.encode_batch(wavs, wav_lens)
        # squeeze(1) drops dimension 1 when it has size 1
        # (presumably the classifier returns (batch, 1, n_classes) -- TODO confirm).
        out_prob = self.mods.classifier(emb).squeeze(1)
        score, index = torch.max(out_prob, dim=-1)
        return out_prob, score, index

    def classify_file(self, path, **kwargs):
        """Classify a single audio file.

        Arguments
        ---------
        path : str
            Location of the audio file; passed to ``load_audio``.
        **kwargs
            Extra keyword arguments forwarded to ``load_audio``.

        Returns
        -------
        out_prob, score, index
            Same as ``classify_batch``, for a batch of one.
        text_lab
            ``index`` decoded to labels by ``hparams.label_encoder``.
        """
        waveform = self.load_audio(path, **kwargs)
        # Fake a batch of one full-length waveform and reuse the batch path.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        out_prob, score, index = self.classify_batch(batch, rel_length)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def forward(self, wavs, wav_lens=None):
        """Runs the classification"""
        return self.classify_batch(wavs, wav_lens)

Overwriting SpeakerCounter.py


In [28]:
# Imported here (not at the top) because SpeakerCounter.py is written by the
# %%file cell above and does not exist before that cell runs.
from SpeakerCounter import SpeakerCounter

wav_path = "/content/session_1_spk_3_mixture.wav"
save_dir = "/content/SaveECAPAsampleinterface"
model_path = "/content"

# Instantiate the interface using from_hparams.
# from_hparams looks for "hyperparams.yaml" by default, but this notebook
# writes "interface_hyperparams.yaml", so the file name is passed explicitly.
# NOTE(review): assumes the %%file cells run with /content as the working
# directory (the Colab default) so the YAML lands inside `model_path`.
audio_classifier = SpeakerCounter.from_hparams(
    source=model_path,
    hparams_file="interface_hyperparams.yaml",
    savedir=save_dir,
)

# pred is (out_prob, score, index, text_lab) as returned by classify_file.
pred = audio_classifier.classify_file(wav_path)
print(pred)

(tensor([[0.3709, 0.7996, 0.9666, 0.9996, 0.9995]]), tensor([0.9996]), tensor([3]), ['3'])
