<a href="https://colab.research.google.com/github/SupradeepDanturti/ConvAIProject/blob/dev2/interface_test_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install speechbrain

In [None]:
""" Pull folder from github """
# !git clone --filter=blob:none --no-checkout https://github.com/SupradeepDanturti/ConvAIProject
# %cd ConvAIProject
# !git sparse-checkout init --cone
# !git sparse-checkout set results
# !git checkout

""" From Google Drive """
%%capture
!pip install --upgrade --no-cache-dir gdown
!gdown 1zDgXx_npH-DigA3shNEzEOTWgsy3MRF2
!unzip results.zip

# Inference Interface for all models
### <font color='#289C4E'>Table of contents</font><a class='anchor' id='top'></a>
- [Results & Comparison](#scrollTo=FQhSGXXCdy2C)
- [XVector](#scrollTo=XNSkswC8RuCZ)
- [Ecapa-TDNN](#scrollTo=R30fxJWFRxIB)
- [Self-supervised](#scrollTo=iOxJgJE_RxOY)
- [Wav2vec2](#scrollTo=ltvGJDdaSDHo)

# Results

# wav2vec2

In [97]:
%%file hyperparams_selfsupervised_xvector.yaml

# Inference hyperparameters for the wav2vec2 + x-vector model.
# The cell magic above writes everything below it to
# hyperparams_selfsupervised_xvector.yaml in the current working directory.

# Target sample rate; SpeakerCounter reads it as `hparams.sample_rate` and
# resamples input audio to it.
sample_rate: 16000
# Passed as `source` to the Wav2Vec2 wrapper below.
sslmodel_hub: facebook/wav2vec2-base
# Passed as `save_path` to the Wav2Vec2 wrapper below.
sslmodel_folder: /content/ssl_checkpoint

# Forwarded to the wrapper's `freeze` / `freeze_feature_extractor` arguments.
freeze_ssl: False
freeze_ssl_conv: True

# encoder_dim is the x-vector model's `in_channels`
# (presumably the wav2vec2-base feature size -- TODO confirm).
encoder_dim: 768
# Embedding size: x-vector `lin_neurons` and classifier input/hidden width.
emb_dim: 128
# Number of classifier outputs.
out_n_neurons: 5

label_encoder: !new:speechbrain.dataio.encoder.CategoricalEncoder
ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <sslmodel_hub>
    output_norm: True
    freeze: !ref <freeze_ssl>
    freeze_feature_extractor: !ref <freeze_ssl_conv>
    save_path: !ref <sslmodel_folder>

# NOTE(review): avg_pool is instantiated but never referenced elsewhere in
# this file (not in `modules`, not in the pretrainer) -- confirm it is needed.
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
    return_std: False

# Mean and std normalization of the input features
mean_var_norm: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence
    std_norm: False

embedding_model: !new:speechbrain.lobes.models.Xvector.Xvector
    in_channels: !ref <encoder_dim>
    activation: !name:torch.nn.LeakyReLU
    tdnn_blocks: 3
    tdnn_channels: [ 64, 64, 64 ]
    tdnn_kernel_sizes: [ 5, 2, 3 ]
    tdnn_dilations: [ 1, 2, 3 ]
    lin_neurons: !ref <emb_dim>

classifier: !new:speechbrain.lobes.models.Xvector.Classifier
    input_shape: [null, null, !ref <emb_dim>]
    activation: !name:torch.nn.LeakyReLU
    lin_blocks: 1
    lin_neurons: !ref <emb_dim>
    out_neurons: !ref <out_n_neurons>

# Modules exposed to the inference interface.
# NOTE(review): SpeakerCounter.MODULES_NEEDED lists `compute_features`, which
# is not defined here (this file provides `ssl_model` instead) -- confirm
# which interface class this file is meant to be loaded with.
modules:
    ssl_model: !ref <ssl_model>
    mean_var_norm: !ref <mean_var_norm>
    embedding_model:  !ref <embedding_model>
    classifier: !ref <classifier>

# Groups the embedding model and classifier so they can be loaded from a
# single `model.ckpt` (see pretrainer paths below).
model: !new:torch.nn.ModuleList
    - [!ref <embedding_model>, !ref <classifier>]


# Directory holding the checkpoint files listed under `paths`.
# NOTE(review): the trailing slash yields `/content//<file>` once the paths
# below are expanded -- confirm this is handled by the loader.
pretrained_path: /content/

# Maps each loadable object to the checkpoint file it is restored from.
# NOTE(review): embedding_model and classifier are listed individually and
# again through `model` -- confirm all of these checkpoint files exist.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
      embedding_model: !ref <embedding_model>
      classifier: !ref <classifier>
      ssl_model: !ref <ssl_model>
      model: !ref <model>
      label_encoder: !ref <label_encoder>
  paths:
      embedding_model: !ref <pretrained_path>/embedding_model.ckpt
      classifier: !ref <pretrained_path>/classifier.ckpt
      ssl_model: !ref <pretrained_path>/ssl_model.ckpt
      model: !ref <pretrained_path>/model.ckpt
      label_encoder: !ref <pretrained_path>/label_encoder.txt

Writing hyperparams_selfsupervised_xvector.yaml


# Base

In [98]:
import torch
from speechbrain.inference.interfaces import Pretrained
import torchaudio
import math
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.fetching import fetch

class SpeakerCounter(Pretrained):
    """Sliding-window speaker-count inference on top of a SpeechBrain model.

    An audio file is cut into fixed-length, overlapping windows. Each window
    is classified independently; consecutive windows with the same label are
    merged into segments, and the boundaries between differently-labelled
    segments are placed according to the classifier's confidence.

    Segments are represented throughout as ``(start_sec, end_sec, label)``
    tuples ordered by start time.
    """

    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model",
        "classifier",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Every waveform is resampled to this rate before windowing.
        self.sample_rate = self.hparams.sample_rate

    def resample_waveform(self, waveform, orig_sample_rate):
        """
        Resample the waveform to a new sample rate.

        Returns the input unchanged when it is already at ``self.sample_rate``.
        """
        if orig_sample_rate != self.sample_rate:
            resample_transform = torchaudio.transforms.Resample(orig_freq=orig_sample_rate, new_freq=self.sample_rate)
            waveform = resample_transform(waveform)
        return waveform

    def merge_overlapping_segments(self, segments):
        """Fuse neighbouring segments that overlap in time and share a label.

        Overlapping segments with different labels are kept as they are.
        Returns a new list; an empty input yields an empty list.
        """
        if not segments:
            return []
        merged = [segments[0]]
        for current in segments[1:]:
            prev = merged[-1]
            if current[0] <= prev[1] and current[2] == prev[2]:
                merged[-1] = (prev[0], max(prev[1], current[1]), prev[2])
            else:
                merged.append(current)
        return merged

    def refine_transitions(self, aggregated_predictions):
        """
        Refines transition times by potentially adjusting them to be at the start
        or end of segments, aiming to make the transitions smoother and more accurate.

        A segment that starts no later than 1.0 s after the previous segment's
        end (this includes overlapping segments) has its start moved to that
        end. The input list is not modified.
        """
        refined_predictions = []
        for i in range(len(aggregated_predictions)):
            if i == 0:
                refined_predictions.append(aggregated_predictions[i])
                continue

            current_start, current_end, current_label = aggregated_predictions[i]
            prev_start, prev_end, prev_label = aggregated_predictions[i-1]

            if current_start - prev_end <= 1.0:
                new_start = prev_end
            else:
                new_start = current_start

            refined_predictions.append((new_start, current_end, current_label))

        return refined_predictions

    def refine_transitions_with_confidence(self, aggregated_predictions, segment_confidences):
        """Place segment boundaries according to per-segment confidence.

        Arguments
        ---------
        aggregated_predictions : list of (start, end, label)
            Segments ordered by start time.
        segment_confidences : list of float
            One score per entry of ``aggregated_predictions``; only their
            relative order matters.

        Returns
        -------
        list of (start, end, label)
            For two neighbouring segments with different labels, the boundary
            is the later segment's start when it is the more confident one,
            otherwise the earlier segment's end. Neighbouring segments with
            the same label are fused when the later one is more confident.
            Segments left with no duration are dropped.

        Raises
        ------
        ValueError
            If the two arguments differ in length.
        """
        if len(segment_confidences) != len(aggregated_predictions):
            raise ValueError(
                "segment_confidences must hold exactly one value per aggregated prediction"
            )

        refined_predictions = []
        for i in range(len(aggregated_predictions)):
            if i == 0:
                refined_predictions.append(aggregated_predictions[i])
                continue

            current_start, current_end, current_label = aggregated_predictions[i]
            prev_start, prev_end, prev_label = refined_predictions[-1]
            prev_confidence = segment_confidences[i-1]
            current_confidence = segment_confidences[i]

            if current_label != prev_label:
                if prev_confidence < current_confidence:
                    transition_point = current_start
                else:
                    transition_point = prev_end
                # The previous segment may already have been trimmed to start
                # after `current_start`; clamp so that no interval is inverted.
                transition_point = min(max(transition_point, prev_start), current_end)
                refined_predictions[-1] = (prev_start, transition_point, prev_label)
                refined_predictions.append((transition_point, current_end, current_label))
            else:
                if prev_confidence < current_confidence:
                    refined_predictions[-1] = (prev_start, current_end, current_label)
                else:
                    refined_predictions.append((current_start, current_end, current_label))

        # Clamping can leave a segment with zero duration; such entries carry
        # no information and are removed.
        return [segment for segment in refined_predictions if segment[1] > segment[0]]

    def aggregate_segments_with_overlap(self, segment_predictions):
        """Collapse per-window predictions into labelled segments.

        Runs of consecutive windows that overlap (or touch) and share a label
        become one segment; the result is then passed through
        ``merge_overlapping_segments``. Returns an empty list when there are
        no window predictions.
        """
        if not segment_predictions:
            return []

        aggregated_predictions = []
        last_start, last_end, last_label = segment_predictions[0]

        for start, end, label in segment_predictions[1:]:
            if label == last_label and start <= last_end:
                last_end = max(last_end, end)
            else:
                aggregated_predictions.append((last_start, last_end, last_label))
                last_start, last_end, last_label = start, end, label

        aggregated_predictions.append((last_start, last_end, last_label))

        return self.merge_overlapping_segments(aggregated_predictions)

    def _aggregate_confidences(self, segment_predictions, segment_scores, aggregated_predictions):
        """Return one confidence value per aggregated segment.

        The confidence of an aggregated segment is the mean score of the
        windows that carry its label and lie entirely inside its time span.
        ``-inf`` is used if no window matches, so such a segment never wins a
        boundary decision.
        """
        confidences = []
        for agg_start, agg_end, agg_label in aggregated_predictions:
            member_scores = [
                score
                for (start, end, label), score in zip(segment_predictions, segment_scores)
                if label == agg_label and agg_start <= start and end <= agg_end
            ]
            if member_scores:
                confidences.append(sum(member_scores) / len(member_scores))
            else:
                confidences.append(float("-inf"))
        return confidences

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Compute embeddings for a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Waveforms; a 1-D tensor is treated as a batch of one.
        wav_lens : torch.Tensor, optional
            Relative lengths, one per batch item; defaults to all ones.
        normalize : bool
            Accepted for interface compatibility; not used here.

        Returns
        -------
        torch.Tensor
            Output of ``embedding_model`` on the normalised features.
        """
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Computing features and embeddings.
        # For the self-supervised variant the feature step would instead be:
        #   feats = self.mods.ssl_model(wavs, wav_lens)
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)
        return embeddings

    def create_segments(self, waveform, segment_length, overlap):
        """Cut a ``(channels, samples)`` waveform into fixed-length windows.

        Arguments
        ---------
        waveform : torch.Tensor
            Audio at ``self.sample_rate``, indexed as ``[channel, sample]``.
        segment_length : float
            Window length in seconds.
        overlap : float
            Overlap between consecutive windows in seconds; must be smaller
            than ``segment_length``.

        Returns
        -------
        (list of torch.Tensor, list of (float, float))
            The windows and their ``(start_sec, end_sec)`` times. Trailing
            audio that does not fill a whole window is not returned, so both
            lists are empty when the waveform is shorter than one window.

        Raises
        ------
        ValueError
            If the window is empty or the overlap leaves no forward step.
        """
        num_samples = waveform.shape[1]
        segment_samples = int(segment_length * self.sample_rate)
        overlap_samples = int(overlap * self.sample_rate)
        step_samples = segment_samples - overlap_samples

        if segment_samples <= 0:
            raise ValueError("segment_length must cover at least one sample")
        if step_samples <= 0:
            # A zero step makes range() fail with an unrelated message and a
            # negative step silently yields no windows.
            raise ValueError("overlap must be smaller than segment_length")

        segments = []
        segment_times = []

        for start in range(0, num_samples - segment_samples + 1, step_samples):
            end = start + segment_samples
            segments.append(waveform[:, start:end])
            start_time = start / self.sample_rate
            end_time = end / self.sample_rate
            segment_times.append((start_time, end_time))

        return segments, segment_times

    def classify_file(self, path, segment_length=2.0, overlap=1.47, **kwargs):
        """Estimate the number of speakers over time in an audio file.

        The file is loaded, mixed down to one channel, resampled, split into
        overlapping windows and classified window by window. The resulting
        segments are printed and written to ``sample_segment_predictions.txt``
        in the current working directory.

        Arguments
        ---------
        path : str
            Audio file readable by ``torchaudio.load``.
        segment_length : float
            Window length in seconds.
        overlap : float
            Overlap between consecutive windows in seconds.

        Returns
        -------
        list of (start_sec, end_sec, label)
            The refined segments; empty when the audio is shorter than one
            window.
        """
        waveform, osr = torchaudio.load(path)
        if waveform.shape[0] > 1:
            # Every row of a window is treated as a separate batch item by
            # encode_batch, so multi-channel audio is mixed down first.
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = self.resample_waveform(waveform, osr)

        segments, segment_times = self.create_segments(waveform, segment_length, overlap)
        segment_predictions = []
        segment_scores = []
        # Each window is passed at its full length.
        rel_length = torch.tensor([1.0])

        with torch.no_grad():  # inference only; no autograd graph needed
            for segment, (start_time, end_time) in zip(segments, segment_times):
                emb = self.encode_batch(segment, rel_length)
                out_prob = self.mods.classifier(emb).squeeze(1)
                score, index = torch.max(out_prob, dim=-1)
                text_lab = self.hparams.label_encoder.decode_torch(index)
                segment_predictions.append((start_time, end_time, text_lab[0]))
                # Keep the winning score so boundaries can be decided on real
                # confidences rather than on the prediction tuples.
                segment_scores.append(float(score.reshape(-1)[0]))

        aggregated_predictions = self.aggregate_segments_with_overlap(segment_predictions)
        segment_confidences = self._aggregate_confidences(
            segment_predictions, segment_scores, aggregated_predictions
        )
        preds = self.refine_transitions_with_confidence(aggregated_predictions, segment_confidences)

        with open("sample_segment_predictions.txt", "w") as file:
            for start_time, end_time, prediction in preds:
                speaker_text = "no speech" if str(prediction) == "0" else ("1 speaker" if str(prediction) == "1" else f"{prediction} speakers")
                print(f"{start_time:.2f}-{end_time:.2f} has {speaker_text}")
                file.write(f"{start_time:.2f}-{end_time:.2f} has {speaker_text}\n")

        return preds

    def forward(self, wavs, wav_lens=None):
        """Runs the classification.

        ``wavs`` is forwarded to ``classify_file`` as the audio path.
        ``wav_lens`` is accepted for interface compatibility only: it must
        not be forwarded positionally, because the second parameter of
        ``classify_file`` is the window length.
        """
        return self.classify_file(wavs)

In [None]:
# SpeakerCounter is the class defined in the cell above; the import below is
# only needed if the class is moved into its own SpeakerCounter.py module.
# from SpeakerCounter import SpeakerCounter
# Audio file to analyse.
wav_path = "/content/session_4_spk_1_mixture_004_segment.wav"
# Passed to from_hparams as `savedir`.
save_dir = "/content/SaveECAPAsampleinterface"
# Passed to from_hparams as `source`.
# NOTE(review): presumably this directory must hold the hyperparameter file
# and the checkpoints it references -- confirm against the SpeechBrain docs.
model_path = "/content"

# Instantiate your class using from_hparams
audio_classifier = SpeakerCounter.from_hparams(source=model_path, savedir=save_dir)

# Prints one "<start>-<end> has ..." line per segment and writes the same
# lines to sample_segment_predictions.txt in the working directory.
audio_classifier.classify_file(wav_path)

0.00-2.00 has 1 speaker
