In [None]:
%%capture

!pip install speechbrain

In [None]:
%%file hyperparams.yaml

# Feature parameters
sample_rate: 16000
n_mels: 80

# Model parameters
# NOTE(review): presumably the number of speaker-count classes; it must agree
# with label_encoder.txt and with the trained classifier checkpoint.
n_classes: 5
# Embedding size produced by the ECAPA model and consumed by the classifier;
# defined once so the two values cannot drift apart.
emb_dim: 192

# Modules
# Turns predicted class indices into text labels (Counter.py calls
# decode_torch on it); its contents are loaded by the pretrainer below.
label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder

compute_features: !new:speechbrain.lobes.features.Fbank
    sample_rate: !ref <sample_rate>
    n_mels: !ref <n_mels>

# NOTE(review): not listed in the pretrainer loadables; presumably
# sentence-level normalization needs no stored statistics — confirm.
mean_var_norm: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence
    std_norm: False

embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN
    input_size: !ref <n_mels>
    channels: [256, 256, 256, 256, 768]
    kernel_sizes: [5, 3, 3, 3, 1]
    dilations: [1, 2, 3, 4, 1]
    attention_channels: 128
    lin_neurons: !ref <emb_dim>

classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
    input_size: !ref <emb_dim>
    out_neurons: !ref <n_classes>

# Modules handed to the Pretrained interface (used as self.mods.<name>).
modules:
    compute_features: !ref <compute_features>
    embedding_model: !ref <embedding_model>
    classifier: !ref <classifier>
    mean_var_norm: !ref <mean_var_norm>

# Directory containing the trained checkpoints and the label encoder file.
pretrained_path: /content

# Loads the trained parameters / label mapping into the objects defined above.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        embedding_model: !ref <embedding_model>
        classifier: !ref <classifier>
        label_encoder: !ref <label_encoder>
    paths:
        embedding_model: !ref <pretrained_path>/embedding_model.ckpt
        classifier: !ref <pretrained_path>/classifier.ckpt
        label_encoder: !ref <pretrained_path>/label_encoder.txt

Overwriting hyperparams.yaml


In [None]:
%%file Counter.py

import torch
from speechbrain.inference.interfaces import Pretrained
import torchaudio
import math
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.fetching import fetch

class SpeakerCounter(Pretrained):
    """SpeechBrain ``Pretrained`` interface for an ECAPA-TDNN speaker-count classifier.

    The hparams file must provide the modules named in ``MODULES_NEEDED`` as
    well as ``sample_rate`` and ``label_encoder``. Instances are built with
    ``SpeakerCounter.from_hparams(source=..., savedir=...)``.
    """

    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model",
        "classifier",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep the configured input sample rate (from hparams) on the instance.
        self.sample_rate = self.hparams.sample_rate

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Compute embeddings for a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            A single 1-D waveform, or a batch with the batch on dim 0
            (presumably [batch, time] — confirm with callers).
        wav_lens : torch.Tensor, optional
            Relative length of each waveform. When omitted, every waveform is
            treated as full length (1.0).
        normalize : bool
            Unused. Accepted for signature compatibility with SpeechBrain's
            ``EncoderClassifier.encode_batch``; no embedding normalizer is
            configured for this model.

        Returns
        -------
        torch.Tensor
            Output of ``embedding_model``.
        """
        # Promote a single waveform to a batch of one.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)

        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        return self.mods.embedding_model(feats, wav_lens)

    def classify_batch(self, wavs, wav_lens=None):
        """Classify a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Waveforms, as accepted by ``encode_batch``.
        wav_lens : torch.Tensor, optional
            Relative lengths, as accepted by ``encode_batch``.

        Returns
        -------
        out_prob : torch.Tensor
            Classifier outputs with the singleton dim 1 squeezed out.
            NOTE(review): these are the classifier's raw scores; whether they
            are normalized probabilities depends on the classifier — confirm.
        score : torch.Tensor
            Highest score along the last dim.
        index : torch.Tensor
            Index of the highest-scoring class. Decode it with
            ``self.hparams.label_encoder.decode_torch``.
        """
        emb = self.encode_batch(wavs, wav_lens)
        out_prob = self.mods.classifier(emb).squeeze(1)
        score, index = torch.max(out_prob, dim=-1)
        return out_prob, score, index

    def classify_file(self, path, **kwargs):
        """Classify a single audio file.

        Arguments
        ---------
        path : str
            Audio file to classify; loaded with ``self.load_audio``.
        **kwargs
            Forwarded to ``self.load_audio``.

        Returns
        -------
        tuple
            ``(out_prob, score, index, text_lab)``. The first three are as in
            ``classify_batch``; ``text_lab`` is ``index`` decoded by the
            label encoder.
        """
        waveform = self.load_audio(path, **kwargs)
        # Fake a batch containing one full-length waveform.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        out_prob, score, index = self.classify_batch(batch, rel_length)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def forward(self, wavs, wav_lens=None):
        """Run the classification; same as ``classify_batch``."""
        return self.classify_batch(wavs, wav_lens)

Writing Counter.py


In [None]:
import importlib

import Counter

# `%%file` rewrites Counter.py, but a repeated `import` returns the cached
# module; reload so edits to the class are picked up without a kernel restart.
importlib.reload(Counter)
# Bind the class itself: `from_hparams` is a classmethod of SpeakerCounter,
# not an attribute of the Counter module.
from Counter import SpeakerCounter

wav_path = "/content/session_4_spk_1_mixture_004_segment.wav"
save_dir = "/content/SaveECAPAsampleinterface"
# Directory searched for hyperparams.yaml (the %%file cell writes it to the
# working directory, presumably /content on Colab — confirm).
model_path = "/content"

audio_classifier = SpeakerCounter.from_hparams(source=model_path, savedir=save_dir)

res = audio_classifier.classify_file(wav_path)
res

AttributeError: module 'Counter' has no attribute 'from_hparams'

In [None]:
import torch
import torchaudio
from torchaudio.transforms import Resample

def divide_audio_into_segments(audio_path, segment_length=2, target_sample_rate=16000,
                               output_dir="/content"):
    """Split an audio file into consecutive fixed-length WAV segments.

    The audio is resampled to ``target_sample_rate`` when the file uses a
    different rate. That way ``segment_length`` seconds maps to the right
    number of samples, and the written files play back at the correct speed.
    Channels are kept exactly as loaded.

    Args:
        audio_path: Path of the input audio file (anything ``torchaudio.load``
            can read).
        segment_length: Segment duration in seconds (int or float). The last
            segment holds whatever audio remains and may be shorter.
        target_sample_rate: Sample rate, in Hz, of the written segments.
            Defaults to 16 kHz, the rate configured in hyperparams.yaml.
        output_dir: Directory the ``segment_<k>.wav`` files are written to,
            with ``k`` starting at 1.

    Returns:
        list[str]: Paths of the written segment files, in order.

    Raises:
        ValueError: If ``segment_length`` amounts to less than one sample at
            ``target_sample_rate``.
    """
    # round() so fractional durations (e.g. 1.5 s) yield an integer sample count.
    num_samples_per_segment = int(round(target_sample_rate * segment_length))
    if num_samples_per_segment < 1:
        raise ValueError(
            f"segment_length={segment_length!r} s is shorter than one sample "
            f"at {target_sample_rate} Hz"
        )

    # torchaudio.load returns (waveform [channels, frames], file_sample_rate).
    waveform, orig_sample_rate = torchaudio.load(audio_path)

    # Segment sizes and the saved header rate are both in target_sample_rate,
    # so the audio must be at that rate before it is cut.
    if orig_sample_rate != target_sample_rate:
        resampler = Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    total_samples = waveform.size(1)
    segment_paths = []
    for index, start_sample in enumerate(range(0, total_samples, num_samples_per_segment)):
        # Slicing clamps at the end of the tensor, so the final segment is just shorter.
        segment = waveform[:, start_sample:start_sample + num_samples_per_segment]
        segment_file_name = f"{output_dir}/segment_{index + 1}.wav"
        torchaudio.save(segment_file_name, segment, target_sample_rate)
        segment_paths.append(segment_file_name)
    return segment_paths

# Split the example two-speaker mixture into 2-second clips.
audio_file_path = "/content/session_2_spk_2_mixture.wav"
divide_audio_into_segments(audio_file_path, segment_length=2);

Saved: segment_1.wav
Saved: segment_2.wav
Saved: segment_3.wav
Saved: segment_4.wav
Saved: segment_5.wav
Saved: segment_6.wav
Saved: segment_7.wav
Saved: segment_8.wav
Saved: segment_9.wav
Saved: segment_10.wav
Saved: segment_11.wav
Saved: segment_12.wav
Saved: segment_13.wav
Saved: segment_14.wav
Saved: segment_15.wav
Saved: segment_16.wav
Saved: segment_17.wav
Saved: segment_18.wav
Saved: segment_19.wav
Saved: segment_20.wav
Saved: segment_21.wav
Saved: segment_22.wav
Saved: segment_23.wav
Saved: segment_24.wav
Saved: segment_25.wav
Saved: segment_26.wav
Saved: segment_27.wav
Saved: segment_28.wav
Saved: segment_29.wav
Saved: segment_30.wav
Saved: segment_31.wav
Saved: segment_32.wav
Saved: segment_33.wav
Saved: segment_34.wav
Saved: segment_35.wav
Saved: segment_36.wav
Saved: segment_37.wav
Saved: segment_38.wav
Saved: segment_39.wav
Saved: segment_40.wav
Saved: segment_41.wav
Saved: segment_42.wav
Saved: segment_43.wav
Saved: segment_44.wav
Saved: segment_45.wav
Saved: segment_46.w