<a href="https://colab.research.google.com/github/SupradeepDanturti/ConvAIProject/blob/main/interface_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install speechbrain

In [2]:
from google.colab import drive
drive.mount('/content/drive/')

!unzip /content/drive/MyDrive/ConvAI/interface_files.zip

Mounted at /content/drive/
Archive:  /content/drive/MyDrive/ConvAI/interface_files.zip
  inflating: classifier.ckpt         
  inflating: embedding_model.ckpt    
  inflating: hyperparams.yaml        
  inflating: label_encoder.txt       
  inflating: session_2_spk_2_mixture.wav  
  inflating: session_4_spk_1_mixture_004_segment.wav  


In [3]:
import torch
from speechbrain.inference.interfaces import Pretrained
import torchaudio
import math
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.fetching import fetch

class SpeakerCounter(Pretrained):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sample_rate = self.hparams.sample_rate

    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model",
        "classifier",
    ]

    def resample_waveform(self, waveform, orig_sample_rate, new_sample_rate=16000):
        """
        Resample the waveform to a new sample rate.
        """
        if orig_sample_rate != new_sample_rate:
            resample_transform = torchaudio.transforms.Resample(orig_freq=orig_sample_rate, new_freq=new_sample_rate)
            waveform = resample_transform(waveform)
        return waveform

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        # Manage single waveforms in input
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Storing waveform in the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Computing features and embeddings
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)
        return embeddings

    def create_segments(self, waveform, segment_length=2.0):
        num_samples = waveform.shape[1]
        segment_samples = int(segment_length * self.sample_rate)
        segments = []
        segment_times = []

        for start in range(0, num_samples, segment_samples):
            end = start + segment_samples
            if end > num_samples:
                end = num_samples
            segments.append(waveform[:, start:end])
            segment_times.append((start / self.sample_rate, end / self.sample_rate))

        return segments, segment_times

    def classify_file(self, path, segment_length=2.0, **kwargs):
        waveform, osr = torchaudio.load(path)
        waveform = self.resample_waveform(waveform, osr)

        segments, segment_times = self.create_segments(waveform, segment_length)
        with open("segment_predictions.txt", "w") as file:
          for segment, (start_time, end_time) in zip(segments, segment_times):

            rel_length = torch.tensor([1.0])

            emb = self.encode_batch(segment, rel_length)
            out_prob = self.mods.classifier(emb).squeeze(1)
            score, index = torch.max(out_prob, dim=-1)
            text_lab = self.hparams.label_encoder.decode_torch(index)
            file.write(f"{start_time:.2f} {end_time:.2f} {text_lab}\n")

    def forward(self, wavs, wav_lens=None):
        """Runs the classification"""
        return self.classify_batch(wavs, wav_lens)

In [4]:
# from SpeakerCounter import SpeakerCounter
wav_path = "/content/session_2_spk_2_mixture.wav"
save_dir = "/content/SaveECAPAsampleinterface"
model_path = "/content"

# Instantiate your class using from_hparams
audio_classifier = SpeakerCounter.from_hparams(source=model_path, savedir=save_dir)

segments  = audio_classifier.classify_file(wav_path)

In [12]:
def encode_batch(self, wavs, wav_lens=None, normalize=False):
    if len(wavs.shape) == 1:
        wavs = wavs.unsqueeze(0)

    wavs = wavs.to(self.device).float()

    # Compute features without passing `lengths` if it causes issues
    feats = self.mods.compute_features(wavs)
    # Use `lengths` here if the layer expects it, otherwise adjust as needed
    feats = self.mods.mean_var_norm(feats, wav_lens) if wav_lens is not None else feats
    embeddings = self.mods.embedding_model(feats)
    return embeddings

In [14]:
segments[1].shape

TypeError: encode_batch() missing 1 required positional argument: 'self'

In [9]:
!rm -rf /content/SaveECAPAsampleinterface

In [33]:
%%file SpeakerCounter.py

import torch
from speechbrain.inference.interfaces import Pretrained
import torchaudio
import math
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.fetching import fetch

class SpeakerCounter(Pretrained):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sample_rate = self.hparams.sample_rate

    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model",
        "classifier",
    ]
    def encode_batch(self, wavs, wav_lens=None, normalize=False):


        # Computing features and embeddings
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)
        return embeddings

    def classify_batch(self, wavs, wav_lens=None):
        emb = self.encode_batch(wavs, wav_lens)
        out_prob = self.mods.classifier(emb).squeeze(1)
        score, index = torch.max(out_prob, dim=-1)
        # text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index
        # return out_prob, score, index, text_lab

    def classify_file(self, path, **kwargs):
        waveform = self.load_audio(path, **kwargs)
        # Fake a batch:
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        emb = self.encode_batch(batch, rel_length)
        out_prob = self.mods.classifier(emb).squeeze(1)
        score, index = torch.max(out_prob, dim=-1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def forward(self, wavs, wav_lens=None):
        """Runs the classification"""
        return self.classify_batch(wavs, wav_lens)

Overwriting SpeakerCounter.py


In [34]:
from SpeakerCounter import SpeakerCounter
wav_path = "/content/session_2_spk_2_mixture.wav"
# wav_path = "/content/session_4_spk_1_mixture_004_segment.wav"
save_dir = "/content/SaveECAPAsampleinterface"
model_path = "/content"

# Instantiate your class using from_hparams
audio_classifier = SpeakerCounter.from_hparams(source=model_path, savedir=save_dir)

# audio_classifier.hparams.label_encoder.ignore_len()
# signal, fs = torchaudio.load(wav_path)
# # pred = audio_classifier.classify_file(wav_path)
# embeddings = audio_classifier.encode_batch(signal)
# prediction = audio_classifier.classify_batch(signal)
# print(prediction)

"""or """
audio_classifier.classify_file(wav_path)

RuntimeError: Failed to open the input "session_2_spk_2_mixture.wav" (Too many levels of symbolic links).
Exception raised from get_input_format_context at /__w/audio/audio/pytorch/audio/src/libtorio/ffmpeg/stream_reader/stream_reader.cpp:42 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7da782cced87 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7da782c7f75f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x42904 (0x7da7828ca904 in /usr/local/lib/python3.10/dist-packages/torio/lib/libtorio_ffmpeg4.so)
frame #3: torio::io::StreamingMediaDecoder::StreamingMediaDecoder(std::string const&, std::optional<std::string> const&, std::optional<std::map<std::string, std::string, std::less<std::string>, std::allocator<std::pair<std::string const, std::string> > > > const&) + 0x14 (0x7da7828cd304 in /usr/local/lib/python3.10/dist-packages/torio/lib/libtorio_ffmpeg4.so)
frame #4: <unknown function> + 0x3a58e (0x7da6bdc3a58e in /usr/local/lib/python3.10/dist-packages/torio/lib/_torio_ffmpeg4.so)
frame #5: <unknown function> + 0x32147 (0x7da6bdc32147 in /usr/local/lib/python3.10/dist-packages/torio/lib/_torio_ffmpeg4.so)
frame #6: <unknown function> + 0x15a10e (0x5bebc993410e in /usr/bin/python3)
frame #7: _PyObject_MakeTpCall + 0x25b (0x5bebc992aa7b in /usr/bin/python3)
frame #8: <unknown function> + 0x168c20 (0x5bebc9942c20 in /usr/bin/python3)
frame #9: <unknown function> + 0x165087 (0x5bebc993f087 in /usr/bin/python3)
frame #10: <unknown function> + 0x150e2b (0x5bebc992ae2b in /usr/bin/python3)
frame #11: <unknown function> + 0xf244 (0x7da7bf592244 in /usr/local/lib/python3.10/dist-packages/torchaudio/lib/_torchaudio.so)
frame #12: _PyObject_MakeTpCall + 0x25b (0x5bebc992aa7b in /usr/bin/python3)
frame #13: _PyEval_EvalFrameDefault + 0x6a79 (0x5bebc9923629 in /usr/bin/python3)
frame #14: _PyObject_FastCallDictTstate + 0xc4 (0x5bebc9929c14 in /usr/bin/python3)
frame #15: <unknown function> + 0x164a64 (0x5bebc993ea64 in /usr/bin/python3)
frame #16: _PyObject_MakeTpCall + 0x1fc (0x5bebc992aa1c in /usr/bin/python3)
frame #17: _PyEval_EvalFrameDefault + 0x6a79 (0x5bebc9923629 in /usr/bin/python3)
frame #18: _PyFunction_Vectorcall + 0x7c (0x5bebc99349fc in /usr/bin/python3)
frame #19: _PyEval_EvalFrameDefault + 0x6bd (0x5bebc991d26d in /usr/bin/python3)
frame #20: _PyFunction_Vectorcall + 0x7c (0x5bebc99349fc in /usr/bin/python3)
frame #21: _PyEval_EvalFrameDefault + 0x614a (0x5bebc9922cfa in /usr/bin/python3)
frame #22: _PyFunction_Vectorcall + 0x7c (0x5bebc99349fc in /usr/bin/python3)
frame #23: _PyEval_EvalFrameDefault + 0x198c (0x5bebc991e53c in /usr/bin/python3)
frame #24: <unknown function> + 0x16893e (0x5bebc994293e in /usr/bin/python3)
frame #25: _PyEval_EvalFrameDefault + 0x2a27 (0x5bebc991f5d7 in /usr/bin/python3)
frame #26: <unknown function> + 0x1687f1 (0x5bebc99427f1 in /usr/bin/python3)
frame #27: _PyEval_EvalFrameDefault + 0x614a (0x5bebc9922cfa in /usr/bin/python3)
frame #28: <unknown function> + 0x13f9c6 (0x5bebc99199c6 in /usr/bin/python3)
frame #29: PyEval_EvalCode + 0x86 (0x5bebc9a0f256 in /usr/bin/python3)
frame #30: <unknown function> + 0x23ae2d (0x5bebc9a14e2d in /usr/bin/python3)
frame #31: <unknown function> + 0x15ac59 (0x5bebc9934c59 in /usr/bin/python3)
frame #32: _PyEval_EvalFrameDefault + 0x6bd (0x5bebc991d26d in /usr/bin/python3)
frame #33: <unknown function> + 0x177ff0 (0x5bebc9951ff0 in /usr/bin/python3)
frame #34: _PyEval_EvalFrameDefault + 0x2568 (0x5bebc991f118 in /usr/bin/python3)
frame #35: <unknown function> + 0x177ff0 (0x5bebc9951ff0 in /usr/bin/python3)
frame #36: _PyEval_EvalFrameDefault + 0x2568 (0x5bebc991f118 in /usr/bin/python3)
frame #37: <unknown function> + 0x177ff0 (0x5bebc9951ff0 in /usr/bin/python3)
frame #38: <unknown function> + 0x2557af (0x5bebc9a2f7af in /usr/bin/python3)
frame #39: <unknown function> + 0x1662ca (0x5bebc99402ca in /usr/bin/python3)
frame #40: _PyEval_EvalFrameDefault + 0x8ac (0x5bebc991d45c in /usr/bin/python3)
frame #41: _PyFunction_Vectorcall + 0x7c (0x5bebc99349fc in /usr/bin/python3)
frame #42: _PyEval_EvalFrameDefault + 0x6bd (0x5bebc991d26d in /usr/bin/python3)
frame #43: _PyFunction_Vectorcall + 0x7c (0x5bebc99349fc in /usr/bin/python3)
frame #44: _PyEval_EvalFrameDefault + 0x8ac (0x5bebc991d45c in /usr/bin/python3)
frame #45: <unknown function> + 0x1687f1 (0x5bebc99427f1 in /usr/bin/python3)
frame #46: PyObject_Call + 0x122 (0x5bebc9943492 in /usr/bin/python3)
frame #47: _PyEval_EvalFrameDefault + 0x2a27 (0x5bebc991f5d7 in /usr/bin/python3)
frame #48: <unknown function> + 0x1687f1 (0x5bebc99427f1 in /usr/bin/python3)
frame #49: _PyEval_EvalFrameDefault + 0x198c (0x5bebc991e53c in /usr/bin/python3)
frame #50: <unknown function> + 0x200175 (0x5bebc99da175 in /usr/bin/python3)
frame #51: <unknown function> + 0x15ac59 (0x5bebc9934c59 in /usr/bin/python3)
frame #52: <unknown function> + 0x236bc5 (0x5bebc9a10bc5 in /usr/bin/python3)
frame #53: <unknown function> + 0x2b2572 (0x5bebc9a8c572 in /usr/bin/python3)
frame #54: <unknown function> + 0x14d99b (0x5bebc992799b in /usr/bin/python3)
frame #55: _PyEval_EvalFrameDefault + 0x6bd (0x5bebc991d26d in /usr/bin/python3)
frame #56: _PyFunction_Vectorcall + 0x7c (0x5bebc99349fc in /usr/bin/python3)
frame #57: _PyEval_EvalFrameDefault + 0x8ac (0x5bebc991d45c in /usr/bin/python3)
frame #58: <unknown function> + 0x200175 (0x5bebc99da175 in /usr/bin/python3)
frame #59: <unknown function> + 0x15ac59 (0x5bebc9934c59 in /usr/bin/python3)
frame #60: <unknown function> + 0x236bc5 (0x5bebc9a10bc5 in /usr/bin/python3)
frame #61: <unknown function> + 0x2b2572 (0x5bebc9a8c572 in /usr/bin/python3)
frame #62: <unknown function> + 0x14d99b (0x5bebc992799b in /usr/bin/python3)
frame #63: _PyEval_EvalFrameDefault + 0x6bd (0x5bebc991d26d in /usr/bin/python3)


In [None]:
import torchaudio
def divide_audio_into_segments(audio_path, segment_length=2):
    # Load the audio file
    sample_rate = 16000
    waveform, _ = torchaudio.load(audio_path)

    # Calculate the number of samples for the given segment length
    num_samples_per_segment = sample_rate * segment_length

    # Calculate the total number of segments, using standard Python operations for ceiling
    total_segments = int(-(-waveform.size(1) // num_samples_per_segment))  # Ceiling division

    # Process and save each segment
    for i in range(total_segments):
        # Calculate the start and end sample for the current segment
        start_sample = i * num_samples_per_segment
        end_sample = start_sample + num_samples_per_segment

        # If the end sample exceeds the waveform length, adjust it to the waveform length
        end_sample = min(end_sample, waveform.size(1))

        # Extract the segment
        segment = waveform[:, start_sample:end_sample]

        # Save the segment to a file
        segment_file_name = f'/content/samples/segment_{i + 1}.wav'
        torchaudio.save(segment_file_name, segment, sample_rate)

divide_audio_into_segments(wav_path)