In [None]:
# MIT License
#
# Copyright (c) 2021 CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import warnings
from functools import cached_property
from pathlib import Path
from typing import Optional, Text, Union

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as kaldi
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError
from torch.nn.utils.rnn import pad_sequence

from pyannote.audio import Inference, Model, Pipeline
from pyannote.audio.core.inference import BaseInference
from pyannote.audio.core.io import AudioFile
from pyannote.audio.core.model import CACHE_DIR
from pyannote.audio.pipelines.utils import PipelineModel, get_model

try:
    from speechbrain.pretrained import (
        EncoderClassifier as SpeechBrain_EncoderClassifier,
    )

    SPEECHBRAIN_IS_AVAILABLE = True
except ImportError:
    SPEECHBRAIN_IS_AVAILABLE = False

try:
    from nemo.collections.asr.models import (
        EncDecSpeakerLabelModel as NeMo_EncDecSpeakerLabelModel,
    )

    NEMO_IS_AVAILABLE = True
except ImportError:
    NEMO_IS_AVAILABLE = False

try:
    import onnxruntime as ort

    ONNX_IS_AVAILABLE = True
except ImportError:
    ONNX_IS_AVAILABLE = False


class NeMoPretrainedSpeakerEmbedding(BaseInference):
    """Pretrained NeMo speaker embedding

    Parameters
    ----------
    embedding : str
        Name of NeMo pretrained speaker model
        (e.g. "nvidia/speakerverification_en_titanet_large").
    device : torch.device, optional
        Device. Defaults to CPU.

    Usage
    -----
    >>> get_embedding = NeMoPretrainedSpeakerEmbedding("nvidia/speakerverification_en_titanet_large")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert masks.ndim == 2
    >>> assert masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=masks)
    """

    def __init__(
        self,
        embedding: Text = "nvidia/speakerverification_en_titanet_large",
        device: Optional[torch.device] = None,
    ):
        if not NEMO_IS_AVAILABLE:
            raise ImportError(
                f"'NeMo' must be installed to use '{embedding}' embeddings. "
                "Visit https://nvidia.github.io/NeMo/ for installation instructions."
            )

        super().__init__()
        self.embedding = embedding
        self.device = device or torch.device("cpu")

        self.model_ = NeMo_EncDecSpeakerLabelModel.from_pretrained(self.embedding)
        self.model_.freeze()
        self.model_.to(self.device)

    def to(self, device: torch.device):
        """Move the underlying NeMo model to `device` and return self."""
        if not isinstance(device, torch.device):
            raise TypeError(
                f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
            )

        self.model_.to(device)
        self.device = device
        return self

    @cached_property
    def sample_rate(self) -> int:
        # falls back to 16kHz when the training config does not specify it
        return self.model_._cfg.train_ds.get("sample_rate", 16000)

    @cached_property
    def dimension(self) -> int:
        # infer embedding dimension from a forward pass on one second of noise
        input_signal = torch.rand(1, self.sample_rate).to(self.device)
        input_signal_length = torch.tensor([self.sample_rate]).to(self.device)
        _, embeddings = self.model_(
            input_signal=input_signal, input_signal_length=input_signal_length
        )
        _, dimension = embeddings.shape
        return dimension

    @cached_property
    def metric(self) -> str:
        return "cosine"

    @cached_property
    def min_num_samples(self) -> int:
        # binary search for the shortest input the model accepts
        # (a RuntimeError from the forward pass means "too short")
        lower, upper = 2, round(0.5 * self.sample_rate)
        middle = (lower + upper) // 2
        while lower + 1 < upper:
            try:
                input_signal = torch.rand(1, middle).to(self.device)
                input_signal_length = torch.tensor([middle]).to(self.device)

                _ = self.model_(
                    input_signal=input_signal, input_signal_length=input_signal_length
                )

                upper = middle
            except RuntimeError:
                lower = middle

            middle = (lower + upper) // 2

        return upper

    def __call__(
        self, waveforms: torch.Tensor, masks: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)
            Only num_channels == 1 is supported.
        masks : (batch_size, num_samples), optional

        Returns
        -------
        embeddings : (batch_size, dimension)
            Rows whose (masked) signal is shorter than `min_num_samples` are NaN.

        """

        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1

        waveforms = waveforms.squeeze(dim=1)

        if masks is None:
            signals = waveforms
            wav_lens = signals.shape[1] * torch.ones(batch_size)

        else:
            batch_size_masks, _ = masks.shape
            assert batch_size == batch_size_masks

            # TODO: speed up the creation of "signals"
            # preliminary profiling experiments show
            # that it accounts for 15% of __call__
            # (the remaining 85% being the actual forward pass)

            imasks = F.interpolate(
                masks.unsqueeze(dim=1), size=num_samples, mode="nearest"
            ).squeeze(dim=1)

            imasks = imasks > 0.5

            # keep only unmasked samples, right-padded to the longest one
            signals = pad_sequence(
                [waveform[imask] for waveform, imask in zip(waveforms, imasks)],
                batch_first=True,
            )

            wav_lens = imasks.sum(dim=1)

        max_len = wav_lens.max()

        # corner case: every signal is too short
        if max_len < self.min_num_samples:
            return np.nan * np.zeros((batch_size, self.dimension))

        # too-short signals are still run through the model (so the batch
        # forward does not fail) and their embeddings are discarded below
        too_short = wav_lens < self.min_num_samples
        wav_lens[too_short] = max_len

        # feed the masked `signals` (not the raw waveforms) so that `wav_lens`
        # describes the audio the model actually sees
        _, embeddings = self.model_(
            input_signal=signals.to(self.device),
            input_signal_length=wav_lens.to(self.device),
        )

        embeddings = embeddings.cpu().numpy()
        embeddings[too_short.cpu().numpy()] = np.nan

        return embeddings


class SpeechBrainPretrainedSpeakerEmbedding(BaseInference):
    """Pretrained SpeechBrain speaker embedding

    Parameters
    ----------
    embedding : str
        Name of SpeechBrain model, optionally suffixed with "@revision".
    device : torch.device, optional
        Device
    use_auth_token : str, optional
        When loading private huggingface.co models, set `use_auth_token`
        to True or to a string containing your hugginface.co authentication
        token that can be obtained by running `huggingface-cli login`

    Usage
    -----
    >>> get_embedding = SpeechBrainPretrainedSpeakerEmbedding("speechbrain/spkrec-ecapa-voxceleb")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert binary_masks.ndim == 2
    >>> assert binary_masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=binary_masks)
    """

    def __init__(
        self,
        embedding: Text = "speechbrain/spkrec-ecapa-voxceleb",
        device: Optional[torch.device] = None,
        use_auth_token: Union[Text, None] = None,
    ):
        if not SPEECHBRAIN_IS_AVAILABLE:
            raise ImportError(
                f"'speechbrain' must be installed to use '{embedding}' embeddings. "
                "Visit https://speechbrain.github.io for installation instructions."
            )

        super().__init__()
        # "repo@revision" selects a specific revision of the model
        if "@" in embedding:
            self.embedding = embedding.split("@")[0]
            self.revision = embedding.split("@")[1]
        else:
            self.embedding = embedding
            self.revision = None
        self.device = device or torch.device("cpu")
        self.use_auth_token = use_auth_token

        self.classifier_ = SpeechBrain_EncoderClassifier.from_hparams(
            source=self.embedding,
            savedir=f"{CACHE_DIR}/speechbrain",
            run_opts={"device": self.device},
            use_auth_token=self.use_auth_token,
            revision=self.revision,
        )

    def to(self, device: torch.device):
        """Reload the SpeechBrain classifier on `device` and return self."""
        if not isinstance(device, torch.device):
            raise TypeError(
                f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
            )

        self.classifier_ = SpeechBrain_EncoderClassifier.from_hparams(
            source=self.embedding,
            savedir=f"{CACHE_DIR}/speechbrain",
            run_opts={"device": device},
            use_auth_token=self.use_auth_token,
            revision=self.revision,
        )
        self.device = device
        return self

    @cached_property
    def sample_rate(self) -> int:
        return self.classifier_.audio_normalizer.sample_rate

    @cached_property
    def dimension(self) -> int:
        # infer embedding dimension from a forward pass on random noise
        dummy_waveforms = torch.rand(1, 16000).to(self.device)
        *_, dimension = self.classifier_.encode_batch(dummy_waveforms).shape
        return dimension

    @cached_property
    def metric(self) -> str:
        return "cosine"

    @cached_property
    def min_num_samples(self) -> int:
        # binary search for the shortest input the classifier accepts
        # (a RuntimeError from encode_batch means "too short")
        with torch.inference_mode():
            lower, upper = 2, round(0.5 * self.sample_rate)
            middle = (lower + upper) // 2
            while lower + 1 < upper:
                try:
                    _ = self.classifier_.encode_batch(
                        torch.randn(1, middle).to(self.device)
                    )
                    upper = middle
                except RuntimeError:
                    lower = middle

                middle = (lower + upper) // 2

        return upper

    def __call__(
        self, waveforms: torch.Tensor, masks: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)
            Only num_channels == 1 is supported.
        masks : (batch_size, num_samples), optional

        Returns
        -------
        embeddings : (batch_size, dimension)
            Rows whose (masked) signal is shorter than `min_num_samples` are NaN.

        """

        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1

        waveforms = waveforms.squeeze(dim=1)

        if masks is None:
            signals = waveforms.squeeze(dim=1)
            wav_lens = signals.shape[1] * torch.ones(batch_size)

        else:
            batch_size_masks, _ = masks.shape
            assert batch_size == batch_size_masks

            # TODO: speed up the creation of "signals"
            # preliminary profiling experiments show
            # that it accounts for 15% of __call__
            # (the remaining 85% being the actual forward pass)

            imasks = F.interpolate(
                masks.unsqueeze(dim=1), size=num_samples, mode="nearest"
            ).squeeze(dim=1)

            imasks = imasks > 0.5

            # keep only unmasked samples, right-padded to the longest one
            signals = pad_sequence(
                [
                    waveform[imask].contiguous()
                    for waveform, imask in zip(waveforms, imasks)
                ],
                batch_first=True,
            )

            wav_lens = imasks.sum(dim=1)

        max_len = wav_lens.max()

        # corner case: every signal is too short
        if max_len < self.min_num_samples:
            return np.nan * np.zeros((batch_size, self.dimension))

        # SpeechBrain expects relative lengths in [0, 1]; too-short signals
        # are processed at full length and their embeddings discarded below
        too_short = wav_lens < self.min_num_samples
        wav_lens = wav_lens / max_len
        wav_lens[too_short] = 1.0

        embeddings = (
            self.classifier_.encode_batch(signals, wav_lens=wav_lens)
            .squeeze(dim=1)
            .cpu()
            .numpy()
        )

        embeddings[too_short.cpu().numpy()] = np.nan

        return embeddings


class ONNXWeSpeakerPretrainedSpeakerEmbedding(BaseInference):
    """Pretrained WeSpeaker speaker embedding

    Parameters
    ----------
    embedding : str
        Path to WeSpeaker pretrained speaker embedding (local ONNX file),
        or name of a huggingface.co repository containing "speaker-embedding.onnx".
    device : torch.device, optional
        Device

    Usage
    -----
    >>> get_embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding("hbredin/wespeaker-voxceleb-resnet34-LM")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert binary_masks.ndim == 2
    >>> assert binary_masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=binary_masks)
    """

    def __init__(
        self,
        embedding: Text = "hbredin/wespeaker-voxceleb-resnet34-LM",
        device: Optional[torch.device] = None,
    ):
        if not ONNX_IS_AVAILABLE:
            raise ImportError(
                f"'onnxruntime' must be installed to use '{embedding}' embeddings."
            )

        super().__init__()

        # not a local file: try to download it from huggingface.co
        if not Path(embedding).exists():
            try:
                embedding = hf_hub_download(
                    repo_id=embedding,
                    filename="speaker-embedding.onnx",
                )
            except RepositoryNotFoundError:
                raise ValueError(
                    f"Could not find '{embedding}' on huggingface.co nor on local disk."
                )

        self.embedding = embedding

        self.to(device or torch.device("cpu"))

    def to(self, device: torch.device):
        """(Re)create the ONNX inference session for `device` and return self.

        Only "cpu" and "cuda" devices are supported; anything else falls back
        to CPU with a warning.
        """
        if not isinstance(device, torch.device):
            raise TypeError(
                f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
            )

        if device.type == "cpu":
            providers = ["CPUExecutionProvider"]
        elif device.type == "cuda":
            providers = [
                (
                    "CUDAExecutionProvider",
                    {
                        "cudnn_conv_algo_search": "DEFAULT",  # EXHAUSTIVE / HEURISTIC / DEFAULT
                    },
                )
            ]
        else:
            warnings.warn(
                f"Unsupported device type: {device.type}, falling back to CPU"
            )
            device = torch.device("cpu")
            providers = ["CPUExecutionProvider"]

        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = 1
        sess_options.intra_op_num_threads = 1
        self.session_ = ort.InferenceSession(
            self.embedding, sess_options=sess_options, providers=providers
        )

        self.device = device
        return self

    @cached_property
    def sample_rate(self) -> int:
        return 16000

    @cached_property
    def dimension(self) -> int:
        # infer embedding dimension from a forward pass on one second of noise
        dummy_waveforms = torch.rand(1, 1, 16000)
        features = self.compute_fbank(dummy_waveforms)
        embeddings = self.session_.run(
            output_names=["embs"], input_feed={"feats": features.numpy()}
        )[0]
        _, dimension = embeddings.shape
        return dimension

    @cached_property
    def metric(self) -> str:
        return "cosine"

    @cached_property
    def min_num_samples(self) -> int:
        # binary search for the shortest input yielding a non-NaN embedding;
        # inputs too short for a single fbank frame raise AssertionError
        lower, upper = 2, round(0.5 * self.sample_rate)
        middle = (lower + upper) // 2
        while lower + 1 < upper:
            try:
                features = self.compute_fbank(torch.randn(1, 1, middle))

            except AssertionError:
                lower = middle
                middle = (lower + upper) // 2
                continue

            embeddings = self.session_.run(
                output_names=["embs"], input_feed={"feats": features.numpy()}
            )[0]

            if np.any(np.isnan(embeddings)):
                lower = middle
            else:
                upper = middle
            middle = (lower + upper) // 2

        return upper

    @cached_property
    def min_num_frames(self) -> int:
        # number of fbank frames produced by `min_num_samples` samples
        return self.compute_fbank(torch.randn(1, 1, self.min_num_samples)).shape[1]

    def compute_fbank(
        self,
        waveforms: torch.Tensor,
        num_mel_bins: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        dither: float = 0.0,
    ) -> torch.Tensor:
        """Extract fbank features

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)

        Returns
        -------
        fbank : (batch_size, num_frames, num_mel_bins)
            Mean-normalized over frames.

        Source: https://github.com/wenet-e2e/wespeaker/blob/45941e7cba2c3ea99e232d02bedf617fc71b0dad/wespeaker/bin/infer_onnx.py#L30C1-L50
        """

        # scale [-1, 1] float audio to 16-bit integer range, as in WeSpeaker
        waveforms = waveforms * (1 << 15)
        features = torch.stack(
            [
                kaldi.fbank(
                    waveform,
                    num_mel_bins=num_mel_bins,
                    frame_length=frame_length,
                    frame_shift=frame_shift,
                    dither=dither,
                    sample_frequency=self.sample_rate,
                    window_type="hamming",
                    use_energy=False,
                )
                for waveform in waveforms
            ]
        )

        return features - torch.mean(features, dim=1, keepdim=True)

    def __call__(
        self, waveforms: torch.Tensor, masks: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)
            Only num_channels == 1 is supported.
        masks : (batch_size, num_samples), optional

        Returns
        -------
        embeddings : (batch_size, dimension)
            Rows whose masked features are shorter than `min_num_frames` are NaN.

        """

        batch_size, num_channels, num_samples = waveforms.shape
        assert num_channels == 1

        features = self.compute_fbank(waveforms.to(self.device))
        _, num_frames, _ = features.shape

        if masks is None:
            embeddings = self.session_.run(
                output_names=["embs"], input_feed={"feats": features.numpy(force=True)}
            )[0]

            return embeddings

        batch_size_masks, _ = masks.shape
        assert batch_size == batch_size_masks

        # resample masks to the frame rate, on the same device as the features
        # so that boolean indexing below does not hit a device mismatch
        imasks = F.interpolate(
            masks.to(features.device).unsqueeze(dim=1), size=num_frames, mode="nearest"
        ).squeeze(dim=1)

        imasks = imasks > 0.5

        embeddings = np.nan * np.zeros((batch_size, self.dimension))

        # masked sequences have different lengths: run them one at a time
        for f, (feature, imask) in enumerate(zip(features, imasks)):
            masked_feature = feature[imask]
            if masked_feature.shape[0] < self.min_num_frames:
                continue

            embeddings[f] = self.session_.run(
                output_names=["embs"],
                input_feed={"feats": masked_feature.numpy(force=True)[None]},
            )[0][0]

        return embeddings


class PyannoteAudioPretrainedSpeakerEmbedding(BaseInference):
    """Pretrained pyannote.audio speaker embedding

    Parameters
    ----------
    embedding : PipelineModel
        pyannote.audio model
    device : torch.device, optional
        Device
    use_auth_token : str, optional
        When loading private huggingface.co models, set `use_auth_token`
        to True or to a string containing your hugginface.co authentication
        token that can be obtained by running `huggingface-cli login`

    Usage
    -----
    >>> get_embedding = PyannoteAudioPretrainedSpeakerEmbedding("pyannote/embedding")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert masks.ndim == 2
    >>> assert masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=masks)
    """

    def __init__(
        self,
        embedding: PipelineModel = "pyannote/embedding",
        device: Optional[torch.device] = None,
        use_auth_token: Union[Text, None] = None,
    ):
        super().__init__()
        self.embedding = embedding
        self.device = device or torch.device("cpu")

        self.model_: Model = get_model(self.embedding, use_auth_token=use_auth_token)
        self.model_.eval()
        self.model_.to(self.device)

    def to(self, device: torch.device):
        """Move the underlying pyannote.audio model to `device` and return self."""
        if not isinstance(device, torch.device):
            raise TypeError(
                f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
            )

        self.model_.to(device)
        self.device = device
        return self

    @cached_property
    def sample_rate(self) -> int:
        return self.model_.audio.sample_rate

    @cached_property
    def dimension(self) -> int:
        return self.model_.dimension

    @cached_property
    def metric(self) -> str:
        return "cosine"

    @cached_property
    def min_num_samples(self) -> int:
        # binary search for the shortest input the model accepts
        # (any exception from the forward pass means "too short")
        with torch.inference_mode():
            lower, upper = 2, round(0.5 * self.sample_rate)
            middle = (lower + upper) // 2
            while lower + 1 < upper:
                try:
                    _ = self.model_(torch.randn(1, 1, middle).to(self.device))
                    upper = middle
                except Exception:
                    lower = middle

                middle = (lower + upper) // 2

        return upper

    def __call__(
        self, waveforms: torch.Tensor, masks: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)
        masks : (batch_size, num_frames), optional
            Passed to the model as `weights`.

        Returns
        -------
        embeddings : (batch_size, dimension)

        """
        with torch.inference_mode():
            if masks is None:
                embeddings = self.model_(waveforms.to(self.device))
            else:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    embeddings = self.model_(
                        waveforms.to(self.device), weights=masks.to(self.device)
                    )
        return embeddings.cpu().numpy()


#!/usr/bin/env python3
import argparse
import logging
import sys
from distutils.version import LooseVersion
from itertools import groupby
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from typeguard import check_argument_types, check_return_type

from espnet2.fileio.npy_scp import NpyScpWriter
from espnet2.tasks.spk import SpeakerTask
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.utils import config_argparse
from espnet2.utils.types import str2bool, str2triple_str, str_or_none
from espnet.utils.cli_utils import get_commandline_args
from pyannote.audio.models.blocks.pooling import StatsPool


class Speech2Embedding:
    """Speaker-embedding extractor wrapping an ESPnet2 speaker model.

    Examples:
        >>> import soundfile
        >>> speech2spkembed = Speech2Embedding("spk_config.yml", "spk.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2spkembed(audio)

    Args:
        train_config: Path to the ESPnet training config.
        model_file: Path to the ESPnet model checkpoint.
        device: Device string passed to ESPnet's model builder (e.g. "cpu").
        dtype: Name of the torch dtype input speech is cast to (e.g. "float32").
        batch_size: Stored on the instance; not used by `__call__`.
        min_num_samples: Inputs (after masking) with fewer samples than this
            get NaN embeddings.
        dimension: Embedding dimension, used to shape all-NaN outputs.
    """

    def __init__(
        self,
        train_config: Optional[Union[Path, str]] = None,
        model_file: Optional[Union[Path, str]] = None,
        device: str = "cpu",
        dtype: str = "float32",
        batch_size: int = 1,
        min_num_samples: int = 200,
        dimension: int = 192,
    ):
        assert check_argument_types()

        spk_model, spk_train_args = SpeakerTask.build_model_from_file(
            train_config, model_file, device
        )
        self.spk_model = spk_model.eval()
        self.spk_train_args = spk_train_args
        self.device = device
        self.dtype = dtype
        self.batch_size = batch_size
        self.min_num_samples = min_num_samples
        self.dimension = dimension

    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        masks: Optional[torch.Tensor] = None,
    ) -> np.ndarray:
        """Extract one speaker embedding per waveform.

        When `masks` is given, masked-out samples are dropped (approach taken
        from pyannote's NeMo wrapper): the kept samples of each waveform are
        concatenated and the batch is right-padded with zeros to the longest
        result before the forward pass.

        Args:
            speech: (batch_size, num_samples) or (batch_size, 1, num_samples).
            masks: Optional (batch_size, mask_length) weights; they are
                nearest-neighbour resampled to num_samples and thresholded at 0.5.

        Returns:
            (batch_size, dimension) array; rows whose kept audio is shorter
            than `min_num_samples` are NaN.
        """
        if isinstance(speech, np.ndarray):
            speech = torch.as_tensor(speech)

        # drop the channel dimension and cast to the configured dtype
        # (e.g. float64 numpy audio -> float32 for the model)
        if speech.dim() == 3 and speech.size(1) == 1:
            speech = speech.squeeze(1)
        speech = speech.to(getattr(torch, self.dtype))

        batch_size, num_samples = speech.shape

        if masks is None:
            signals = speech
            wav_lens = signals.shape[1] * torch.ones(batch_size)
        else:
            batch_size_masks, _ = masks.shape
            assert batch_size == batch_size_masks

            imasks = F.interpolate(
                masks.unsqueeze(dim=1), size=num_samples, mode="nearest"
            ).squeeze(dim=1)

            imasks = imasks > 0.5

            signals = pad_sequence(
                [waveform[imask] for waveform, imask in zip(speech, imasks)],
                batch_first=True,
            )

            wav_lens = imasks.sum(dim=1)

        max_len = wav_lens.max()

        # corner case: every signal is too short
        if max_len < self.min_num_samples:
            return np.nan * np.zeros((batch_size, self.dimension))

        # too-short rows still go through the batched forward pass; their
        # embeddings are replaced by NaN afterwards
        too_short = wav_lens < self.min_num_samples

        batch = {"speech": signals.to(self.device), "extract_embd": True}
        output = self.spk_model(**batch)

        embeddings = output.cpu().numpy()
        embeddings[too_short.cpu().numpy()] = np.nan

        return embeddings

    def to(self, device: Union[torch.device, str], dtype: str = "float32"):
        """Move the model to the device

        Args:
            device: target device
            dtype: name of the torch dtype used for input speech

        Returns:
            self

        """
        self.spk_model.to(device)
        self.device = device
        self.dtype = dtype
        return self

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build Speech2Embedding instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.
            **kwargs: Forwarded to `Speech2Embedding.__init__` (after being
                updated with the downloaded config/model paths).

        Returns:
            Speech2Embedding: Speech2Embedding instance.

        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2Embedding(**kwargs)


class ESPnetSPKSpeakerEmbedding(BaseInference):
    """Pretrained ESPnet-SPK speaker embedding

    Parameters
    ----------
    embedding : str
        espnet_model_zoo model tag (e.g. "espnet/voxcelebs12_rawnet3").
    device : torch.device, optional
        Device. Defaults to CPU.

    Usage
    -----
    >>> get_embedding = ESPnetSPKSpeakerEmbedding("espnet/voxcelebs12_rawnet3")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert masks.ndim == 2
    >>> assert masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=masks)
    """

    def __init__(
        self,
        embedding: str = "espnet/voxcelebs12_rawnet3",
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.embedding = embedding
        self.device = device or torch.device("cpu")

        # Speech2Embedding's `device` argument is type-checked as `str`.
        # min_num_samples=0 disables the wrapper's own short-input guard so the
        # probe in `min_num_samples` sees the underlying model's failures.
        self.model_ = Speech2Embedding.from_pretrained(
            model_tag=self.embedding,
            device=str(self.device),
            min_num_samples=0,
            dimension=self.dimension,
        )

        # probe only once `model_` exists (probing earlier fails on the
        # missing attribute and caches a meaningless value), then hand the
        # result to the wrapper so it NaNs out too-short inputs
        self.model_.min_num_samples = self.min_num_samples

    def to(self, device: torch.device):
        """Move the wrapped ESPnet model to `device` and return self."""
        if not isinstance(device, torch.device):
            raise TypeError(
                f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
            )

        self.model_.to(device)
        self.device = device
        return self

    @cached_property
    def sample_rate(self) -> int:
        return 16000

    @cached_property
    def dimension(self) -> int:
        # NOTE(review): hard-coded; matches the default rawnet3 model only
        return 192

    @cached_property
    def metric(self) -> str:
        return "cosine"

    @cached_property
    def min_num_samples(self) -> int:
        # binary search for the shortest input the model accepts
        # (any exception from the forward pass means "too short")
        with torch.inference_mode():
            lower, upper = 2, round(0.5 * self.sample_rate)
            middle = (lower + upper) // 2
            while lower + 1 < upper:
                try:
                    _ = self.model_(torch.randn(1, 1, middle).to(self.device))
                    upper = middle
                except Exception:
                    lower = middle

                middle = (lower + upper) // 2

        return upper

    def __call__(
        self, waveforms: torch.Tensor, masks: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)
        masks : (batch_size, num_frames), optional

        Returns
        -------
        embeddings : (batch_size, dimension)
            As returned by `Speech2Embedding.__call__`.

        """
        with torch.inference_mode():
            if masks is None:
                embeddings = self.model_(waveforms.to(self.device))
            else:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    embeddings = self.model_(
                        waveforms.to(self.device), masks=masks.to(self.device)
                    )
        return embeddings

def PretrainedSpeakerEmbedding(
    embedding: PipelineModel,
    device: Optional[torch.device] = None,
    use_auth_token: Union[Text, None] = None,
):
    """Build the speaker embedding wrapper matching `embedding`.

    String identifiers are dispatched on the first matching substring, checked
    in this order:

    - "pyannote"    -> PyannoteAudioPretrainedSpeakerEmbedding
    - "speechbrain" -> SpeechBrainPretrainedSpeakerEmbedding
    - "nvidia"      -> NeMoPretrainedSpeakerEmbedding
    - "wespeaker"   -> ONNXWeSpeakerPretrainedSpeakerEmbedding
    - "espnet"      -> ESPnetSPKSpeakerEmbedding

    Anything else, including non-string values such as a local model, falls
    back to PyannoteAudioPretrainedSpeakerEmbedding.

    Parameters
    ----------
    embedding : Text or other PipelineModel
        Model identifier, e.g. "speechbrain/spkrec-ecapa-voxceleb",
        "espnet/voxcelebs12_rawnet3", or a pyannote.audio model.
    device : torch.device, optional
        Device forwarded to the selected wrapper.
    use_auth_token : str, optional
        Only forwarded to the pyannote and speechbrain wrappers. Set it to True
        or to a huggingface.co token (see `huggingface-cli login`) to load
        private models.

    Usage
    -----
    >>> get_embedding = PretrainedSpeakerEmbedding("pyannote/embedding")
    >>> get_embedding = PretrainedSpeakerEmbedding("speechbrain/spkrec-ecapa-voxceleb")
    >>> get_embedding = PretrainedSpeakerEmbedding("nvidia/speakerverification_en_titanet_large")
    >>> get_embedding = PretrainedSpeakerEmbedding("espnet/voxcelebs12_rawnet3")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> # masks are presumably (batch_size, num_frames), one row per waveform
    >>> assert masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=masks)
    """
    if isinstance(embedding, str):
        if "pyannote" in embedding:
            return PyannoteAudioPretrainedSpeakerEmbedding(
                embedding, device=device, use_auth_token=use_auth_token
            )
        if "speechbrain" in embedding:
            return SpeechBrainPretrainedSpeakerEmbedding(
                embedding, device=device, use_auth_token=use_auth_token
            )
        if "nvidia" in embedding:
            return NeMoPretrainedSpeakerEmbedding(embedding, device=device)
        if "wespeaker" in embedding:
            return ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding, device=device)
        if "espnet" in embedding:
            return ESPnetSPKSpeakerEmbedding(embedding, device=device)

    # no known prefix (or not a string): assume a pyannote model, e.g. a local checkpoint
    return PyannoteAudioPretrainedSpeakerEmbedding(
        embedding, device=device, use_auth_token=use_auth_token
    )


class SpeakerEmbedding(Pipeline):
    """Single-speaker embedding pipeline.

    Every input file is assumed to contain exactly one speaker, and the whole
    file is turned into one embedding. When a segmentation (voice activity)
    model is provided, its frame-level scores are passed to the embedding model
    as `weights`.

    Parameters
    ----------
    embedding : Model, str, or dict, optional
        Pretrained embedding model, loaded with
        pyannote.audio.pipelines.utils.get_model. Defaults to "pyannote/embedding".
    segmentation : Model, str, or dict, optional
        Pretrained segmentation (or voice activity detection) model, loaded with
        pyannote.audio.pipelines.utils.get_model. Defaults to None, meaning no
        voice activity weighting.
    use_auth_token : str, optional
        Forwarded to `get_model` when loading private huggingface.co models.
        Set to True or to a token obtained with `huggingface-cli login`.

    Usage
    -----
    >>> from pyannote.audio.pipelines import SpeakerEmbedding
    >>> pipeline = SpeakerEmbedding()
    >>> emb1 = pipeline("speaker1.wav")
    >>> emb2 = pipeline("speaker2.wav")
    >>> from scipy.spatial.distance import cdist
    >>> distance = cdist(emb1, emb2, metric="cosine")[0,0]
    """

    def __init__(
        self,
        embedding: PipelineModel = "pyannote/embedding",
        segmentation: Optional[PipelineModel] = None,
        use_auth_token: Union[Text, None] = None,
    ):
        super().__init__()

        self.embedding = embedding
        self.segmentation = segmentation

        self.embedding_model_: Model = get_model(embedding, use_auth_token=use_auth_token)

        # no voice activity model requested: nothing else to build
        if self.segmentation is None:
            return

        def keep_max_over_last_axis(scores):
            # collapse the last axis (presumably speakers/classes) into one
            # activity score per frame
            return np.max(scores, axis=-1, keepdims=True)

        vad_model: Model = get_model(self.segmentation, use_auth_token=use_auth_token)
        self._segmentation = Inference(
            vad_model,
            pre_aggregation_hook=keep_max_over_last_axis,
        )

    def apply(self, file: AudioFile) -> np.ndarray:
        """Compute a single embedding for the whole of `file`.

        Returns
        -------
        embedding : np.ndarray
            Output of the embedding model, moved to CPU. Presumably
            (1, dimension), since the waveform carries a batch axis of size 1.
        """
        model = self.embedding_model_
        target = model.device

        # take the waveform (first element returned by `model.audio`), add a
        # leading batch axis, and move it to the model's device
        waveform = model.audio(file)[0][None].to(target)

        weights = None
        if self.segmentation is not None:
            activity = self._segmentation(file).data
            # HACK -- this should be fixed upstream: NaN scores are treated as 0
            activity[np.isnan(activity)] = 0.0
            # cube the scores, keep the single collapsed channel, add batch axis
            weights = torch.from_numpy(activity**3)[None, :, 0].to(target)

        # embed the whole waveform, optionally weighted by voice activity
        with torch.no_grad():
            return model(waveform, weights=weights).cpu().numpy()


def main(
    protocol: str = "VoxCeleb.SpeakerVerification.VoxCeleb1",
    subset: str = "test",
    embedding: str = "pyannote/embedding",
    segmentation: Optional[str] = None,
):
    """Print the equal error rate of `SpeakerEmbedding` on a verification protocol.

    Each trial of `subset` compares two files. Their embeddings are compared
    with the cosine distance, and the resulting distances feed `det_curve`.

    Parameters
    ----------
    protocol : str
        pyannote.database protocol name.
    subset : str
        Subset whose `<subset>_trial` method yields the trials.
    embedding, segmentation : str, optional
        Forwarded to `SpeakerEmbedding`.
    """
    import typer
    from pyannote.database import FileFinder, get_protocol
    from pyannote.metrics.binary_classification import det_curve
    from scipy.spatial.distance import cdist
    from tqdm import tqdm

    pipeline = SpeakerEmbedding(embedding=embedding, segmentation=segmentation)

    evaluation = get_protocol(protocol, preprocessors={"audio": FileFinder()})

    # memoize per audio path: a file may take part in several trials
    cache = {}

    def embed(audio):
        if audio not in cache:
            cache[audio] = pipeline(audio)
        return cache[audio]

    labels, distances = [], []
    for trial in tqdm(getattr(evaluation, f"{subset}_trial")()):
        first = embed(trial["file1"]["audio"])
        second = embed(trial["file2"]["audio"])
        distances.append(cdist(first, second, metric="cosine")[0][0])
        labels.append(trial["reference"])

    _, _, _, eer = det_curve(labels, np.array(distances), distances=True)
    typer.echo(
        f"{evaluation.name} | {subset} | {embedding} | {segmentation} | EER = {100 * eer:.3f}%"
    )





In [None]:
# pipeline = PretrainedSpeakerEmbedding("espnet/voxcelebs12_rawnet3", device="mps")
# emb1 = pipeline(torch.randn(66, 1, 16000))

# print(emb1.shape)



In [None]:
# The MIT License (MIT)
#
# Copyright (c) 2021- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Speaker diarization pipelines"""

import functools
import itertools
import math
import textwrap
import warnings
from typing import Callable, Optional, Text, Union

import numpy as np
import torch
from einops import rearrange
from pyannote.core import Annotation, SlidingWindowFeature
from pyannote.metrics.diarization import GreedyDiarizationErrorRate
from pyannote.pipeline.parameter import ParamDict, Uniform

from pyannote.audio import Audio, Inference, Model, Pipeline
from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines.clustering import Clustering
from pyannote.audio.pipelines.utils import (
    PipelineModel,
    SpeakerDiarizationMixin,
    get_model,
)
from pyannote.audio.utils.signal import binarize


def batchify(iterable, batch_size: int = 32, fillvalue=None):
    """Group `iterable` into tuples of `batch_size` items.

    The last tuple is padded with `fillvalue` when the input runs out. The
    result is a lazy iterator.

    >>> list(batchify('ABCDEFG', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', None, None)]
    """
    # the same iterator is repeated `batch_size` times, so zip_longest pulls
    # consecutive items from it to fill each tuple
    source = iter(iterable)
    return itertools.zip_longest(
        *itertools.repeat(source, batch_size), fillvalue=fillvalue
    )


class SpeakerDiarization(SpeakerDiarizationMixin, Pipeline):
    """Speaker diarization pipeline

    Parameters
    ----------
    segmentation : Model, str, or dict, optional
        Pretrained segmentation model. Defaults to "pyannote/segmentation@2022.07".
        See pyannote.audio.pipelines.utils.get_model for supported format.
    segmentation_step: float, optional
        The segmentation model is applied on a window sliding over the whole audio file.
        `segmentation_step` controls the step of this window, provided as a ratio of its
        duration. Defaults to 0.1 (i.e. 90% overlap between two consecutive windows).
    embedding : Model, str, or dict, optional
        Pretrained embedding model, instantiated through `PretrainedSpeakerEmbedding`.
        Defaults to
        "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e".
        Not loaded when `clustering` is "OracleClustering".
    embedding_exclude_overlap : bool, optional
        Exclude overlapping speech regions when extracting embeddings.
        Defaults (False) to use the whole speech.
    clustering : str, optional
        Clustering algorithm. See pyannote.audio.pipelines.clustering.Clustering
        for available options. Defaults to "AgglomerativeClustering".
    segmentation_batch_size : int, optional
        Batch size used for speaker segmentation. Defaults to 1.
    embedding_batch_size : int, optional
        Batch size used for speaker embedding. Defaults to 1.
    der_variant : dict, optional
        Optimize for a variant of diarization error rate.
        Defaults to {"collar": 0.0, "skip_overlap": False}. This is used in `get_metric`
        when instantiating the metric: GreedyDiarizationErrorRate(**der_variant).
    use_auth_token : str, optional
        When loading private huggingface.co models, set `use_auth_token`
        to True or to a string containing your huggingface.co authentication
        token that can be obtained by running `huggingface-cli login`

    Usage
    -----
    # perform (unconstrained) diarization
    >>> diarization = pipeline("/path/to/audio.wav")

    # perform diarization, targeting exactly 4 speakers
    >>> diarization = pipeline("/path/to/audio.wav", num_speakers=4)

    # perform diarization, with at least 2 speakers and at most 10 speakers
    >>> diarization = pipeline("/path/to/audio.wav", min_speakers=2, max_speakers=10)

    # perform diarization and get one representative embedding per speaker
    >>> diarization, embeddings = pipeline("/path/to/audio.wav", return_embeddings=True)
    >>> for s, speaker in enumerate(diarization.labels()):
    ...     # embeddings[s] is the embedding of speaker `speaker`

    Hyper-parameters
    ----------------
    segmentation.threshold
    segmentation.min_duration_off
    clustering.???
    """

    def __init__(
        self,
        segmentation: PipelineModel = "pyannote/segmentation@2022.07",
        segmentation_step: float = 0.1,
        embedding: PipelineModel = "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e",
        embedding_exclude_overlap: bool = False,
        clustering: str = "AgglomerativeClustering",
        embedding_batch_size: int = 1,
        segmentation_batch_size: int = 1,
        der_variant: Optional[dict] = None,
        use_auth_token: Union[Text, None] = None,
    ):
        super().__init__()

        # keep the user-provided `segmentation` argument, then load the model itself
        self.segmentation_model = segmentation
        model: Model = get_model(segmentation, use_auth_token=use_auth_token)

        self.segmentation_step = segmentation_step

        self.embedding = embedding
        self.embedding_batch_size = embedding_batch_size
        self.embedding_exclude_overlap = embedding_exclude_overlap

        # the clustering *name* is kept as `klustering` because `self.clustering`
        # is bound to the instantiated clustering object at the end of __init__
        self.klustering = clustering

        self.der_variant = der_variant or {"collar": 0.0, "skip_overlap": False}

        segmentation_duration = model.specifications.duration

        # sliding-window segmentation without aggregation, so that the raw
        # per-chunk scores are returned (see `get_segmentations`)
        self._segmentation = Inference(
            model,
            duration=segmentation_duration,
            step=self.segmentation_step * segmentation_duration,
            skip_aggregation=True,
            batch_size=segmentation_batch_size,
        )

        # powerset outputs are used as-is in `apply` (no binarization), hence no
        # `threshold` hyper-parameter in that case
        if self._segmentation.model.specifications.powerset:
            self.segmentation = ParamDict(
                min_duration_off=Uniform(0.0, 1.0),
            )

        else:
            self.segmentation = ParamDict(
                threshold=Uniform(0.1, 0.9),
                min_duration_off=Uniform(0.0, 1.0),
            )

        # oracle clustering never looks at embeddings: skip loading the model
        if self.klustering == "OracleClustering":
            metric = "not_applicable"

        else:
            self._embedding = PretrainedSpeakerEmbedding(
                self.embedding, use_auth_token=use_auth_token
            )
            self._audio = Audio(sample_rate=self._embedding.sample_rate, mono="downmix")
            metric = self._embedding.metric

        try:
            Klustering = Clustering[clustering]
        except KeyError:
            raise ValueError(
                f'clustering must be one of [{", ".join(list(Clustering.__members__))}]'
            )
        self.clustering = Klustering.value(metric=metric)

    @property
    def segmentation_batch_size(self) -> int:
        """Batch size of the underlying segmentation inference."""
        return self._segmentation.batch_size

    @segmentation_batch_size.setter
    def segmentation_batch_size(self, batch_size: int):
        self._segmentation.batch_size = batch_size

    def default_parameters(self):
        """Not provided for this pipeline; always raises NotImplementedError."""
        raise NotImplementedError()

    def classes(self):
        """Yield an endless sequence of labels: SPEAKER_00, SPEAKER_01, ..."""
        speaker = 0
        while True:
            yield f"SPEAKER_{speaker:02d}"
            speaker += 1

    @property
    def CACHED_SEGMENTATION(self):
        """Key under which segmentations are cached in `file` during training."""
        return "training_cache/segmentation"

    def get_segmentations(self, file, hook=None) -> SlidingWindowFeature:
        """Apply segmentation model

        Parameter
        ---------
        file : AudioFile
        hook : Optional[Callable]

        Returns
        -------
        segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
        """

        if hook is not None:
            hook = functools.partial(hook, "segmentation", None)

        # in training mode, segmentations are computed once per file and reused
        if self.training:
            if self.CACHED_SEGMENTATION in file:
                segmentations = file[self.CACHED_SEGMENTATION]
            else:
                segmentations = self._segmentation(file, hook=hook)
                file[self.CACHED_SEGMENTATION] = segmentations
        else:
            segmentations: SlidingWindowFeature = self._segmentation(file, hook=hook)

        return segmentations

    def get_embeddings(
        self,
        file,
        binary_segmentations: SlidingWindowFeature,
        exclude_overlap: bool = False,
        hook: Optional[Callable] = None,
    ):
        """Extract embeddings for each (chunk, speaker) pair

        Parameters
        ----------
        file : AudioFile
        binary_segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
            Binarized segmentation.
        exclude_overlap : bool, optional
            Exclude overlapping speech regions when extracting embeddings.
            In case non-overlapping speech is too short, use the whole speech.
        hook: Optional[Callable]
            Called during embeddings after every batch to report the progress

        Returns
        -------
        embeddings : (num_chunks, num_speakers, dimension) array
        """

        # when optimizing the hyper-parameters of this pipeline with frozen
        # "segmentation.threshold", one can reuse the embeddings from the first trial,
        # bringing a massive speed up to the optimization process (and hence allowing to use
        # a larger search space).
        if self.training:
            # we only re-use embeddings if they were extracted based on the same value of the
            # "segmentation.threshold" hyperparameter or if the segmentation model relies on
            # `powerset` mode
            cache = file.get("training_cache/embeddings", dict())
            if ("embeddings" in cache) and (
                self._segmentation.model.specifications.powerset
                or (cache["segmentation.threshold"] == self.segmentation.threshold)
            ):
                return cache["embeddings"]

        duration = binary_segmentations.sliding_window.duration
        num_chunks, num_frames, num_speakers = binary_segmentations.data.shape

        if exclude_overlap:
            # minimum number of samples needed to extract an embedding
            # (a lower number of samples would result in an error)
            min_num_samples = self._embedding.min_num_samples

            # corresponding minimum number of frames
            num_samples = duration * self._embedding.sample_rate
            min_num_frames = math.ceil(num_frames * min_num_samples / num_samples)

            # zero-out frames with overlapping speech
            clean_frames = 1.0 * (
                np.sum(binary_segmentations.data, axis=2, keepdims=True) < 2
            )
            clean_segmentations = SlidingWindowFeature(
                binary_segmentations.data * clean_frames,
                binary_segmentations.sliding_window,
            )

        else:
            # -1 makes the "enough clean frames" test below always pass
            min_num_frames = -1
            clean_segmentations = SlidingWindowFeature(
                binary_segmentations.data, binary_segmentations.sliding_window
            )

        def iter_waveform_and_mask():
            for (chunk, masks), (_, clean_masks) in zip(
                binary_segmentations, clean_segmentations
            ):
                # chunk: Segment(t, t + duration)
                # masks: (num_frames, local_num_speakers) np.ndarray

                waveform, _ = self._audio.crop(
                    file,
                    chunk,
                    duration=duration,
                    mode="pad",
                )
                # waveform: (1, num_samples) torch.Tensor

                # mask may contain NaN (in case of partial stitching)
                masks = np.nan_to_num(masks, nan=0.0).astype(np.float32)
                clean_masks = np.nan_to_num(clean_masks, nan=0.0).astype(np.float32)

                for mask, clean_mask in zip(masks.T, clean_masks.T):
                    # mask: (num_frames, ) np.ndarray

                    # prefer the overlap-free mask unless it leaves too few frames
                    if np.sum(clean_mask) > min_num_frames:
                        used_mask = clean_mask
                    else:
                        used_mask = mask

                    yield waveform[None], torch.from_numpy(used_mask)[None]
                    # w: (1, 1, num_samples) torch.Tensor
                    # m: (1, num_frames) torch.Tensor

        # the last batch is padded with (None, None) pairs, filtered out below
        batches = batchify(
            iter_waveform_and_mask(),
            batch_size=self.embedding_batch_size,
            fillvalue=(None, None),
        )

        batch_count = math.ceil(num_chunks * num_speakers / self.embedding_batch_size)

        embedding_batches = []

        if hook is not None:
            hook("embeddings", None, total=batch_count, completed=0)

        for i, batch in enumerate(batches, 1):
            waveforms, masks = zip(*filter(lambda b: b[0] is not None, batch))

            waveform_batch = torch.vstack(waveforms)
            # (batch_size, 1, num_samples) torch.Tensor

            mask_batch = torch.vstack(masks)
            # (batch_size, num_frames) torch.Tensor

            embedding_batch = self._embedding(waveform_batch, masks=mask_batch)

            # some embedding wrappers may return a torch.Tensor (possibly on an
            # accelerator); np.vstack below needs host-side numpy arrays
            if isinstance(embedding_batch, torch.Tensor):
                embedding_batch = embedding_batch.detach().cpu().numpy()
            # (batch_size, dimension) np.ndarray

            embedding_batches.append(embedding_batch)

            if hook is not None:
                hook("embeddings", embedding_batch, total=batch_count, completed=i)

        embedding_batches = np.vstack(embedding_batches)

        embeddings = rearrange(embedding_batches, "(c s) d -> c s d", c=num_chunks)

        # caching embeddings for subsequent trials
        # (see comments at the top of this method for more details)
        if self.training:
            if self._segmentation.model.specifications.powerset:
                file["training_cache/embeddings"] = {
                    "embeddings": embeddings,
                }
            else:
                file["training_cache/embeddings"] = {
                    "segmentation.threshold": self.segmentation.threshold,
                    "embeddings": embeddings,
                }

        return embeddings

    def reconstruct(
        self,
        segmentations: SlidingWindowFeature,
        hard_clusters: np.ndarray,
        count: SlidingWindowFeature,
    ) -> SlidingWindowFeature:
        """Build final discrete diarization out of clustered segmentation

        Parameters
        ----------
        segmentations : (num_chunks, num_frames, num_speakers) SlidingWindowFeature
            Raw speaker segmentation.
        hard_clusters : (num_chunks, num_speakers) array
            Output of clustering step.
        count : (total_num_frames, 1) SlidingWindowFeature
            Instantaneous number of active speakers.

        Returns
        -------
        discrete_diarization : SlidingWindowFeature
            Discrete (0s and 1s) diarization.
        """

        num_chunks, num_frames, local_num_speakers = segmentations.data.shape

        num_clusters = np.max(hard_clusters) + 1
        # NaN marks (chunk, cluster) pairs with no contributing local speaker.
        # `np.NAN` was removed in NumPy 2.0, so build the array with `np.nan`.
        clustered_segmentations = np.full(
            (num_chunks, num_frames, num_clusters), np.nan
        )

        for c, (cluster, (chunk, segmentation)) in enumerate(
            zip(hard_clusters, segmentations)
        ):
            # cluster is (local_num_speakers, )-shaped
            # segmentation is (num_frames, local_num_speakers)-shaped
            for k in np.unique(cluster):
                # -2 flags inactive speakers (set in `apply`)
                if k == -2:
                    continue

                # TODO: can we do better than this max here?
                clustered_segmentations[c, :, k] = np.max(
                    segmentation[:, cluster == k], axis=1
                )

        clustered_segmentations = SlidingWindowFeature(
            clustered_segmentations, segmentations.sliding_window
        )

        return self.to_diarization(clustered_segmentations, count)

    def apply(
        self,
        file: AudioFile,
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
        return_embeddings: bool = False,
        hook: Optional[Callable] = None,
    ) -> Annotation:
        """Apply speaker diarization

        Parameters
        ----------
        file : AudioFile
            Processed file.
        num_speakers : int, optional
            Number of speakers, when known.
        min_speakers : int, optional
            Minimum number of speakers. Has no effect when `num_speakers` is provided.
        max_speakers : int, optional
            Maximum number of speakers. Has no effect when `num_speakers` is provided.
        return_embeddings : bool, optional
            Return representative speaker embeddings.
        hook : callable, optional
            Callback called after each major steps of the pipeline as follows:
                hook(step_name,      # human-readable name of current step
                     step_artefact,  # artifact generated by current step
                     file=file)      # file being processed
            Time-consuming steps call `hook` multiple times with the same `step_name`
            and additional `completed` and `total` keyword arguments usable to track
            progress of current step.

        Returns
        -------
        diarization : Annotation
            Speaker diarization
        embeddings : np.array, optional
            Representative speaker embeddings such that `embeddings[i]` is the
            speaker embedding for i-th speaker in diarization.labels().
            Only returned when `return_embeddings` is True.
        """

        # setup hook (e.g. for debugging purposes)
        hook = self.setup_hook(file, hook=hook)

        num_speakers, min_speakers, max_speakers = self.set_num_speakers(
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )

        segmentations = self.get_segmentations(file, hook=hook)
        hook("segmentation", segmentations)
        #   shape: (num_chunks, num_frames, local_num_speakers)
        num_chunks, num_frames, local_num_speakers = segmentations.data.shape

        # binarize segmentation
        if self._segmentation.model.specifications.powerset:
            binarized_segmentations = segmentations
        else:
            binarized_segmentations: SlidingWindowFeature = binarize(
                segmentations,
                onset=self.segmentation.threshold,
                initial_state=False,
            )

        hook("binarized_segmentation", binarized_segmentations)
        # estimate frame-level number of instantaneous speakers
        count = self.speaker_count(
            binarized_segmentations,
            self._segmentation.model.receptive_field,
            warm_up=(0.0, 0.0),
        )
        hook("speaker_counting", count)
        #   shape: (num_frames, 1)
        #   dtype: int

        # exit early when no speaker is ever active
        if np.nanmax(count.data) == 0.0:
            diarization = Annotation(uri=file["uri"])
            if return_embeddings:
                return diarization, np.zeros((0, self._embedding.dimension))

            return diarization

        # skip speaker embedding extraction and clustering when only one speaker
        if not return_embeddings and max_speakers < 2:
            hard_clusters = np.zeros((num_chunks, local_num_speakers), dtype=np.int8)
            embeddings = None
            centroids = None

        else:

            # skip speaker embedding extraction with oracle clustering
            if self.klustering == "OracleClustering" and not return_embeddings:
                embeddings = None

            else:
                embeddings = self.get_embeddings(
                    file,
                    binarized_segmentations,
                    exclude_overlap=self.embedding_exclude_overlap,
                    hook=hook,
                )
                hook("embeddings", embeddings)
                #   shape: (num_chunks, local_num_speakers, dimension)

            hard_clusters, _, centroids = self.clustering(
                embeddings=embeddings,
                segmentations=binarized_segmentations,
                num_clusters=num_speakers,
                min_clusters=min_speakers,
                max_clusters=max_speakers,
                file=file,  # <== for oracle clustering
                frames=self._segmentation.model.receptive_field,  # <== for oracle clustering
            )
            # hook("hard_clusters", hard_clusters)
            # hard_clusters: (num_chunks, num_speakers)
            # centroids: (num_speakers, dimension)

        # number of detected clusters is the number of different speakers
        num_different_speakers = np.max(hard_clusters) + 1

        # detected number of speakers can still be out of bounds
        # (specifically, lower than `min_speakers`), since there could be too few embeddings
        # to make enough clusters with a given minimum cluster size.
        if (
            num_different_speakers < min_speakers
            or num_different_speakers > max_speakers
        ):
            warnings.warn(
                textwrap.dedent(
                    f"""
                The detected number of speakers ({num_different_speakers}) is outside
                the given bounds [{min_speakers}, {max_speakers}]. This can happen if the
                given audio file is too short to contain {min_speakers} or more speakers.
                Try to lower the desired minimal number of speakers.
                """
                )
            )

        # during counting, we could possibly overcount the number of instantaneous
        # speakers due to segmentation errors, so we cap the maximum instantaneous number
        # of speakers by the `max_speakers` value
        count.data = np.minimum(count.data, max_speakers).astype(np.int8)

        # reconstruct discrete diarization from raw hard clusters

        # keep track of inactive speakers
        inactive_speakers = np.sum(binarized_segmentations.data, axis=1) == 0
        #   shape: (num_chunks, num_speakers)

        # -2 tells `reconstruct` to ignore these (chunk, speaker) pairs
        hard_clusters[inactive_speakers] = -2

        hook("hard_clusters", hard_clusters)

        discrete_diarization = self.reconstruct(
            segmentations,
            hard_clusters,
            count,
        )
        hook("discrete_diarization", discrete_diarization)

        # convert to continuous diarization
        diarization = self.to_annotation(
            discrete_diarization,
            min_duration_on=0.0,
            min_duration_off=self.segmentation.min_duration_off,
        )
        diarization.uri = file["uri"]

        # at this point, `diarization` speaker labels are integers
        # from 0 to `num_speakers - 1`, aligned with `centroids` rows.

        if "annotation" in file and file["annotation"]:
            # when reference is available, use it to map hypothesized speakers
            # to reference speakers (this makes later error analysis easier
            # but does not modify the actual output of the diarization pipeline)
            _, mapping = self.optimal_mapping(
                file["annotation"], diarization, return_mapping=True
            )

            # in case there are more speakers in the hypothesis than in
            # the reference, those extra speakers are missing from `mapping`.
            # we add them back here
            mapping = {key: mapping.get(key, key) for key in diarization.labels()}

        else:
            # when reference is not available, rename hypothesized speakers
            # to human-readable SPEAKER_00, SPEAKER_01, ...
            mapping = {
                label: expected_label
                for label, expected_label in zip(diarization.labels(), self.classes())
            }

        diarization = diarization.rename_labels(mapping=mapping)

        # at this point, `diarization` speaker labels are strings (or mix of
        # strings and integers when reference is available and some hypothesis
        # speakers are not present in the reference)

        if not return_embeddings:
            return diarization

        # this can happen when we use OracleClustering
        if centroids is None:
            return diarization, None

        # The number of centroids may be smaller than the number of speakers
        # in the annotation. This can happen if the number of active speakers
        # obtained from `speaker_count` for some frames is larger than the number
        # of clusters obtained from `clustering`. In this case, we append zero embeddings
        # for extra speakers
        if len(diarization.labels()) > centroids.shape[0]:
            centroids = np.pad(
                centroids, ((0, len(diarization.labels()) - centroids.shape[0]), (0, 0))
            )

        # re-order centroids so that they match
        # the order given by diarization.labels()
        inverse_mapping = {label: index for index, label in mapping.items()}
        centroids = centroids[
            [inverse_mapping[label] for label in diarization.labels()]
        ]

        return diarization, centroids

    def get_metric(self) -> GreedyDiarizationErrorRate:
        """Metric used to evaluate/optimize this pipeline, built from `der_variant`."""
        return GreedyDiarizationErrorRate(**self.der_variant)


In [None]:
# Run the diarization pipeline on every mic and collect embeddings, hard clusters, diarizations and centroids


import pickle
import einops
import torch
import numpy as np

import yaml
from pyannote.audio import Pipeline

from typing import Any, Mapping, Optional, Text
import torch
from copy import deepcopy
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from pyannote.audio.pipelines.utils.hook import ArtifactHook,ProgressHook

class CombinedHook:
    """Composite hook that saves artifacts and shows progress of each internal step.

    Forwards every pipeline hook call to both an `ArtifactHook` and a
    `ProgressHook`, so a single `hook=` argument does both jobs.

    Parameters
    ----------
    *artifacts : str, optional
        Names of the steps whose artifacts should be saved. Defaults to all steps.
    file_key : str, optional
        Key used to store artifacts in `file`.
        Defaults to "artifact".
    transient : bool, optional
        Clear the progress on exit. Defaults to False.

    Usage
    -----
    >>> with CombinedHook() as hook:
    ...     output = pipeline(file, hook=hook)
    # file["artifact"] contains a dict with artifacts of each step
    """

    def __init__(self, *artifacts, file_key: str = "artifact", transient: bool = False):
        self.artifact_hook = ArtifactHook(*artifacts, file_key=file_key)
        self.progress_hook = ProgressHook(transient=transient)

    def __enter__(self):
        # entered artifact-first; __exit__ unwinds in reverse order
        self.artifact_hook.__enter__()
        self.progress_hook.__enter__()
        return self

    def __exit__(self, *args):
        # Exit in reverse order of entry (like nested `with` blocks), and make
        # sure the artifact hook is exited even if stopping the progress
        # display raises.
        try:
            self.progress_hook.__exit__(*args)
        finally:
            self.artifact_hook.__exit__(*args)

    def __call__(
        self,
        step_name: Text,
        step_artifact: Any,
        file: Optional[Mapping] = None,
        total: Optional[int] = None,
        completed: Optional[int] = None,
    ):
        """Forward one pipeline step notification to both wrapped hooks."""
        self.artifact_hook(step_name, step_artifact, file, total, completed)
        self.progress_hook(step_name, step_artifact, file, total, completed)


# # perform speaker diarization on full audio
    # pipeline = Pipeline.from_pretrained(
    # "./config2.yaml",
    #     use_auth_token=os.environ["HF_TOKEN"])  # never hard-code access tokens

#     pipeline:
#   name: pyannote.audio.pipelines.SpeakerDiarization
#   params:
#     clustering: AgglomerativeClustering
#     embedding: pyannote/embedding
#     embedding_batch_size: 32
#     embedding_exclude_overlap: true
#     segmentation: pyannote/segmentation-3.0
#     segmentation_batch_size: 32

# params:
#   clustering:
#     method: centroid
#     min_cluster_size: 12
#     threshold: 0.7045654963945799
#   segmentation:
#     min_duration_off: 0.0


    # pipeline = SpeakerDiarization(  
    #     segmentation =  "pyannote/segmentation@2022.07",
    #     segmentation_step= 0.1,
    #     embedding: PipelineModel = "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e",
    #     embedding_exclude_overlap: bool = False,
    #     clustering: str = "AgglomerativeClustering",
    #     embedding_batch_size: int = 1,
    #     segmentation_batch_size: int = 1,
    #     der_variant: Optional[dict] = None,
    #     use_auth_token: Union[Text, None] = None,

# Hugging Face access token: read it from the environment instead of
# hard-coding it in the notebook (committed tokens leak).
# NOTE(review): a token used to be hard-coded here -- revoke it on
# huggingface.co and export HF_TOKEN before running this cell.
import os

HF_TOKEN = os.environ.get("HF_TOKEN")

pipeline = SpeakerDiarization(
    segmentation="pyannote/segmentation-3.0",
    segmentation_step=0.1,
    embedding="espnet/voxcelebs12_rawnet3",
    embedding_exclude_overlap=True,
    clustering="AgglomerativeClustering",
    embedding_batch_size=32,
    segmentation_batch_size=32,
    der_variant=None,
    use_auth_token=HF_TOKEN,
)

# Prefer the Apple-silicon GPU (MPS) as before, but fall back to CUDA / CPU so
# the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
pipeline.to(DEVICE)

# Hyper-parameters of the pipeline (same structure as a pyannote config.yaml).
cc = """
params:
    clustering:
        method: centroid
        min_cluster_size: 12
        threshold: 0.7045654963945799
    segmentation:
        min_duration_off: 0.0
"""

# TODO: find the default segmentation threshold (0.5 would only be a guess).
# NOTE: `segmentation.threshold` is not set on purpose: the segmentation model is
# a powerset model, so no threshold applies to it.

# plain data only -> safe_load (no arbitrary Python object construction)
cc = yaml.safe_load(cc)

print(cc)

pipeline.instantiate(cc["params"])




# Load the source-separated recordings.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you
# produced yourself.
with open("outputttt3_final_output.pkl", "rb") as f:
    data = pickle.load(f)

# presumably (num_mics, num_chunks, num_channels, samples_per_chunk), per the
# per-mic shape comments below -- TODO confirm
print(data.shape)

sr = 16000  # sample rate of the recordings, in Hz

all_full_embedings = []
all_hard_clusters = []
all_diars = []
all_centroids = []
for i in range(data.shape[0]):

    mic = data[i]
    print(mic.shape)

    # (30, 6, 160000) -> (6, 160000 * 30): stitch the chunks of each channel together
    mic = einops.rearrange(mic, "a b c -> b (a c)")
    print(mic.shape)

    # (6, 4800000) -> (4800000 * 6,): concatenate the channels back to back
    mic_full_flat = einops.rearrange(mic, "a b -> (a b)")
    print(mic_full_flat.shape)  # (28800000,)

    # In-memory audio for pyannote: {"waveform": (channel, time) float32 tensor,
    # "sample_rate": int}. Cast explicitly so float64 input does not reach the
    # float32 models.
    waveform = torch.from_numpy(np.ascontiguousarray(mic_full_flat, dtype=np.float32)).unsqueeze(0)
    audio_in_memory = {"waveform": waveform, "sample_rate": sr}

    # CombinedHook stores per-step artifacts in audio_in_memory["artifact"]
    with CombinedHook() as hook:
        # `embedings` are the speaker centroids, ordered like diarization.labels()
        diarization, embedings = pipeline(audio_in_memory, hook=hook, return_embeddings=True)

    print(audio_in_memory.keys())

    full_embedings = audio_in_memory["artifact"]["embeddings"]  # (num_chunks, local_num_speakers, dimension)
    hard_clusters = audio_in_memory["artifact"]["hard_clusters"]  # (num_chunks, local_num_speakers)

    all_full_embedings.append(full_embedings)
    all_hard_clusters.append(hard_clusters)
    all_diars.append(diarization)
    all_centroids.append(embedings)



In [None]:
# Peak absolute amplitude of every diarized speaker, per mic.
# The loudest speaker on a mic is the most likely owner of that mic.
import plotly.graph_objects as go

sr = 16000  # sample rate in Hz

all_max_amplitude = []

for i in range(data.shape[0]):
    current_diar = all_diars[i]

    # same layout that was fed to the pipeline: (30, 6, 160000) -> (6 * 30 * 160000,),
    # i.e. the separated channels concatenated back to back
    current_mic = einops.rearrange(data[i], "a b c -> (b a c)")

    current_speaker_max_amplitude = {}
    for turn, _, speaker in current_diar.itertracks(yield_label=True):
        segment = np.asarray(current_mic[int(turn.start * sr):int(turn.end * sr)])
        # a turn shorter than one sample yields an empty slice, on which np.max
        # raises; count it as silence so the speaker still gets an entry
        max_amplitude = np.max(np.abs(segment)) if segment.size else 0
        current_speaker_max_amplitude[speaker] = max(
            max_amplitude, current_speaker_max_amplitude.get(speaker, 0)
        )

    # loudest speaker first
    current_speaker_max_amplitude = dict(
        sorted(current_speaker_max_amplitude.items(), key=lambda item: item[1], reverse=True)
    )
    all_max_amplitude.append(current_speaker_max_amplitude)

print(all_max_amplitude)

# bar chart of the peak amplitude of every (mic, speaker)
fig = go.Figure()

for i, current_max_amplitude in enumerate(all_max_amplitude):
    for speaker, amplitude in current_max_amplitude.items():
        fig.add_trace(go.Bar(x=[f"mic_{i}_{speaker}"], y=[amplitude]))

fig.update_layout(yaxis_title="Amplitude", xaxis_title="Speaker", title="Speaker Amplitude")

fig.show()

    

In [None]:
# Most probable owner of each mic: the speaker with the highest peak amplitude.
# all_max_amplitude[i] is sorted by amplitude (descending), so its first entry
# is the loudest speaker. Mics without any diarized speaker get None.
best_speakers_per_mic = []
for i, current_max_amplitude in enumerate(all_max_amplitude):
    if not current_max_amplitude:
        best_speakers_per_mic.append(None)
        print(f"mic_{i} has no diarized speakers")
        continue
    best_speaker, best_amplitude = next(iter(current_max_amplitude.items()))
    best_speakers_per_mic.append(best_speaker)
    print(f"mic_{i} best speaker {best_speaker} with max amplitude {best_amplitude}")
    


    

In [None]:
# Project the speaker centroids of every mic into a shared 3-D UMAP space.
import umap
import plotly.express as px
import pandas as pd

centroid_speaker_labels_run = []
runsss = []
sizes = []

for i, current_centroids in enumerate(all_centroids):
    # the pipeline re-orders centroids so that row j matches all_diars[i].labels()[j]
    speaker_labels = all_diars[i].labels()
    assert current_centroids.shape[0] == len(speaker_labels), "centroids and labels are out of sync"

    for j, speaker in enumerate(speaker_labels):
        runsss.append(i)
        centroid_speaker_labels_run.append(f"mic_{i}_speaker_{j}")
        # marker size proportional to the speaker's peak amplitude on this mic
        sizes.append(all_max_amplitude[i].get(speaker, 0.0))

print(centroid_speaker_labels_run)

all_centroids_flat = np.concatenate(all_centroids, axis=0)

print(all_centroids_flat.shape)

# fixed seed so the layout is reproducible between runs
reducer = umap.UMAP(metric="cosine", n_components=3, random_state=42)

umap_centroids = reducer.fit_transform(all_centroids_flat)

df = pd.DataFrame(umap_centroids, columns=["x", "y", "z"])
df["speaker"] = centroid_speaker_labels_run
df["run"] = [f"run_{run}" for run in runsss]
df["size"] = sizes

fig = px.scatter_3d(df, x="x", y="y", z="z", color="speaker", size="size", hover_data=["run"])
fig.show()

fig = px.scatter_3d(df, x="x", y="y", z="z", color="run")
fig.show()

# "Individual speakers are visible, as well as clusters that mix several
# speakers at once where the separation did not work well."
print("vidi se posamezne speakerje in clustre ki so mixi vecih speakerjev naenkrat ko ne locimo dobro")



In [None]:
# Project the speaker centroids of a single mic into 3-D with UMAP.
import umap
import plotly.express as px
import pandas as pd

mic_to_plot = 1

current_centroids = all_centroids[mic_to_plot]

# the pipeline re-orders centroids so that row j matches labels()[j]
speaker_labels = all_diars[mic_to_plot].labels()
assert current_centroids.shape[0] == len(speaker_labels), "centroids and labels are out of sync"

runsss = [mic_to_plot] * len(speaker_labels)
centroid_speaker_labels_run = [f"mic_{mic_to_plot}_speaker_{j}" for j in range(len(speaker_labels))]
# marker size proportional to the speaker's peak amplitude; 0 (not None) when
# unknown, since plotly cannot size markers from missing values
sizes = [all_max_amplitude[mic_to_plot].get(speaker, 0.0) for speaker in speaker_labels]

print(centroid_speaker_labels_run)

all_centroids_flat = current_centroids

print(all_centroids_flat.shape)

# fixed seed so the layout is reproducible between runs
reducer = umap.UMAP(metric="cosine", n_components=3, random_state=42)

umap_centroids = reducer.fit_transform(all_centroids_flat)

df = pd.DataFrame(umap_centroids, columns=["x", "y", "z"])
df["speaker"] = centroid_speaker_labels_run
df["run"] = [f"run_{run}" for run in runsss]
df["size"] = sizes

fig = px.scatter_3d(df, x="x", y="y", z="z", color="speaker", size="size", hover_data=["run"])

fig.show()



In [None]:
# Diarization timeline of every mic: one row per (mic, speaker).
# Each mic signal is its separated channels concatenated back to back, each
# CHANNEL_DURATION seconds long, so turns are drawn on the channel-local time
# axis. Colour = rank of the speaker's peak amplitude on that mic (darkest red =
# loudest speaker = most likely the mic owner).
import plotly.graph_objects as go
import plotly.express as px

CHANNEL_DURATION = 300.0  # seconds per separated channel (30 chunks x 10 s)

fig = go.Figure()

# y-axis rows, de-duplicated and sorted
unique_labels = sorted(
    {f"mic_{i}_{speaker}" for i, diar in enumerate(all_diars) for speaker in diar.labels()}
)

reds = px.colors.sequential.Reds
legend_shown = set()

for i, current_diar in enumerate(all_diars):
    # 0 = loudest speaker (all_max_amplitude[i] is sorted by peak, descending)
    speaker_rank = {speaker: rank for rank, speaker in enumerate(all_max_amplitude[i])}

    for turn, _, speaker in current_diar.itertracks(yield_label=True):
        y_label = f"mic_{i}_{speaker}"

        # clamp: with more speakers than palette entries a negative index
        # would wrap around to the darkest ("loudest") colour
        c = reds[max(len(reds) - 1 - speaker_rank.get(speaker, len(reds) - 1), 0)]

        # a turn may straddle the boundary between two concatenated channels:
        # draw one piece per channel it overlaps (exact times, no rounding)
        first_channel = int(turn.start // CHANNEL_DURATION)
        last_channel = int(turn.end // CHANNEL_DURATION)
        for channel in range(first_channel, last_channel + 1):
            offset = channel * CHANNEL_DURATION
            start = max(turn.start, offset) - offset
            end = min(turn.end, offset + CHANNEL_DURATION) - offset
            if end <= start:
                # e.g. a turn ending exactly on a channel boundary
                continue

            fig.add_trace(go.Scatter(
                x=[start, end],
                y=[y_label, y_label],
                mode='lines',
                line=dict(color=c, width=10),
                name=y_label,
                legendgroup=f"mic_{i}",
                showlegend=(y_label not in legend_shown)
            ))
            legend_shown.add(y_label)

# Update layout to set y-axis as category type and use the unique labels
fig.update_layout(
    yaxis=dict(
        title='Speakers',
        tickmode='array',
        tickvals=unique_labels,
        ticktext=unique_labels,
        categoryorder='array',
        categoryarray=unique_labels
    ),
    xaxis=dict(title='Time'),
    title='Diarization Visualization',
    legend=dict(title='Speakers', itemsizing='constant')
)

fig.show(renderer="browser")

In [None]:
# Mic 0: diarization vs. hand-labelled ground truth ("GT" row).
# Turns are coloured by their energy on mic 0's OWN signal; times are
# channel-local (each separated channel is CHANNEL_DURATION seconds long).
import plotly.graph_objects as go
import plotly.express as px

CHANNEL_DURATION = 300.0  # seconds per separated channel (30 chunks x 10 s)
sr = 16000  # sample rate in Hz

fig = go.Figure()

i = 0

current_diar = all_diars[i]

# y-axis rows: every speaker of this mic plus the ground-truth row, sorted
unique_labels = sorted({f"mic_{i}_{speaker}" for speaker in current_diar.labels()} | {"GT"})

# hand-labelled segments of the mic owner (channel-local seconds)
gt_Segments = [[66,70],[75,77],[113,118],[155,171],[171.5,172],[225,239],[249,254]]

for k, (start, end) in enumerate(gt_Segments):
    fig.add_trace(go.Scatter(
        x=[start, end],
        y=["GT", "GT"],
        mode='lines',
        line=dict(color="green", width=10),
        # name the trace after its own row (a stale `y_label` left over from
        # another cell used to be used here)
        name="GT",
        legendgroup="GT",
        showlegend=(k == 0)
    ))

# This mic's signal, laid out as it was fed to the pipeline (channels
# concatenated back to back). `audio_in_memory` from the diarization loop holds
# the LAST mic, so it must not be used to measure mic i.
flat_signal = einops.rearrange(data[i], "a b c -> (b a c)")

reds = px.colors.sequential.Reds
legend_shown = set()

for turn, _, speaker in current_diar.itertracks(yield_label=True):
    y_label = f"mic_{i}_{speaker}"

    segment = np.asarray(flat_signal[int(turn.start * sr):int(turn.end * sr)], dtype=np.float64)
    energy = np.sum(segment ** 2)
    # empirical scaling: bin the energy into the 6 lightest shades of the palette
    energy_bin = min(int(energy * 100000), 5)
    c = reds[energy_bin]

    # a turn may straddle the boundary between two concatenated channels:
    # draw one piece per channel it overlaps (exact times, no rounding)
    first_channel = int(turn.start // CHANNEL_DURATION)
    last_channel = int(turn.end // CHANNEL_DURATION)
    for channel in range(first_channel, last_channel + 1):
        offset = channel * CHANNEL_DURATION
        start = max(turn.start, offset) - offset
        end = min(turn.end, offset + CHANNEL_DURATION) - offset
        if end <= start:
            # e.g. a turn ending exactly on a channel boundary
            continue

        fig.add_trace(go.Scatter(
            x=[start, end],
            y=[y_label, y_label],
            mode='lines',
            line=dict(color=c, width=10),
            name=y_label,
            legendgroup=f"mic_{i}",
            showlegend=(y_label not in legend_shown)
        ))
        legend_shown.add(y_label)

# Update layout to set y-axis as category type and use the unique labels
fig.update_layout(
    yaxis=dict(
        title='Speakers',
        tickmode='array',
        tickvals=unique_labels,
        ticktext=unique_labels,
        categoryorder='array',
        categoryarray=unique_labels
    ),
    xaxis=dict(title='Time'),
    title='Diarization Visualization',
    legend=dict(title='Speakers', itemsizing='constant')
)

fig.show(renderer="browser")

In [None]:
# Mic 1: diarization vs. hand-labelled ground truth ("GT" row).
# Colour = rank of the speaker's peak amplitude on this mic (darkest red =
# loudest); times are channel-local (each channel is CHANNEL_DURATION s long).
import plotly.graph_objects as go
import plotly.express as px

CHANNEL_DURATION = 300.0  # seconds per separated channel (30 chunks x 10 s)

fig = go.Figure()

i = 1

current_diar = all_diars[i]

# y-axis rows: every speaker of this mic plus the ground-truth row, sorted
unique_labels = sorted({f"mic_{i}_{speaker}" for speaker in current_diar.labels()} | {"GT"})

# hand-labelled segments of the mic owner (channel-local seconds)
gt_Segments = [[0,9],[15,17.9],[18.2,20.4],[29.8,30.10],[39.5,39.8],[40.8,49.37],[52.10,52.5],[60.13,60.5],[63.30,64.3],[201.8,207.8],[212.4,212.6],[215.1,216.1],[223.3,225],[247.7,249.8]]

for k, (start, end) in enumerate(gt_Segments):
    fig.add_trace(go.Scatter(
        x=[start, end],
        y=["GT", "GT"],
        mode='lines',
        line=dict(color="green", width=10),
        # name the trace after its own row (a stale `y_label` left over from
        # another cell used to be used here)
        name="GT",
        legendgroup="GT",
        showlegend=(k == 0)
    ))

reds = px.colors.sequential.Reds
legend_shown = set()

# 0 = loudest speaker (all_max_amplitude[i] is sorted by peak, descending)
speaker_rank = {speaker: rank for rank, speaker in enumerate(all_max_amplitude[i])}

for turn, _, speaker in current_diar.itertracks(yield_label=True):
    y_label = f"mic_{i}_{speaker}"

    # clamp: with more speakers than palette entries a negative index would
    # wrap around to the darkest ("loudest") colour
    c = reds[max(len(reds) - 1 - speaker_rank.get(speaker, len(reds) - 1), 0)]

    # a turn may straddle the boundary between two concatenated channels:
    # draw one piece per channel it overlaps (exact times, no rounding)
    first_channel = int(turn.start // CHANNEL_DURATION)
    last_channel = int(turn.end // CHANNEL_DURATION)
    for channel in range(first_channel, last_channel + 1):
        offset = channel * CHANNEL_DURATION
        start = max(turn.start, offset) - offset
        end = min(turn.end, offset + CHANNEL_DURATION) - offset
        if end <= start:
            # e.g. a turn ending exactly on a channel boundary
            continue

        fig.add_trace(go.Scatter(
            x=[start, end],
            y=[y_label, y_label],
            mode='lines',
            line=dict(color=c, width=10),
            name=y_label,
            legendgroup=f"mic_{i}",
            showlegend=(y_label not in legend_shown)
        ))
        legend_shown.add(y_label)

# Update layout to set y-axis as category type and use the unique labels
fig.update_layout(
    yaxis=dict(
        title='Speakers',
        tickmode='array',
        tickvals=unique_labels,
        ticktext=unique_labels,
        categoryorder='array',
        categoryarray=unique_labels
    ),
    xaxis=dict(title='Time'),
    title='Diarization Visualization',
    legend=dict(title='Speakers', itemsizing='constant')
)

fig.show(renderer="browser")

In [None]:
# Mic 5: diarization vs. hand-labelled ground truth ("GT" row).
# Colour = rank of the speaker's peak amplitude on this mic (darkest red =
# loudest); times are channel-local (each channel is CHANNEL_DURATION s long).
import plotly.graph_objects as go
import plotly.express as px

CHANNEL_DURATION = 300.0  # seconds per separated channel (30 chunks x 10 s)

fig = go.Figure()

i = 5

current_diar = all_diars[i]

# y-axis rows: every speaker of this mic plus the ground-truth row, sorted
unique_labels = sorted({f"mic_{i}_{speaker}" for speaker in current_diar.labels()} | {"GT"})

# hand-labelled segments of the mic owner (channel-local seconds)
gt_Segments = [[10.7,14.4],[21.7,41.5],[48.4,63.2],[95.8,96.8],[102.2,102.8],[109.6,111.0],[114,118],[128,144],[155,155.8],[183,202.2],[261.3,262],[268.9,269],[274.2,278.7],[281.8,282]]

for k, (start, end) in enumerate(gt_Segments):
    fig.add_trace(go.Scatter(
        x=[start, end],
        y=["GT", "GT"],
        mode='lines',
        line=dict(color="green", width=10),
        # name the trace after its own row (a stale `y_label` left over from
        # another cell used to be used here)
        name="GT",
        legendgroup="GT",
        showlegend=(k == 0)
    ))

reds = px.colors.sequential.Reds
legend_shown = set()

# 0 = loudest speaker (all_max_amplitude[i] is sorted by peak, descending)
speaker_rank = {speaker: rank for rank, speaker in enumerate(all_max_amplitude[i])}

for turn, _, speaker in current_diar.itertracks(yield_label=True):
    y_label = f"mic_{i}_{speaker}"

    # clamp: with more speakers than palette entries a negative index would
    # wrap around to the darkest ("loudest") colour
    c = reds[max(len(reds) - 1 - speaker_rank.get(speaker, len(reds) - 1), 0)]

    # a turn may straddle the boundary between two concatenated channels:
    # draw one piece per channel it overlaps (exact times, no rounding)
    first_channel = int(turn.start // CHANNEL_DURATION)
    last_channel = int(turn.end // CHANNEL_DURATION)
    for channel in range(first_channel, last_channel + 1):
        offset = channel * CHANNEL_DURATION
        start = max(turn.start, offset) - offset
        end = min(turn.end, offset + CHANNEL_DURATION) - offset
        if end <= start:
            # e.g. a turn ending exactly on a channel boundary
            continue

        fig.add_trace(go.Scatter(
            x=[start, end],
            y=[y_label, y_label],
            mode='lines',
            line=dict(color=c, width=10),
            name=y_label,
            legendgroup=f"mic_{i}",
            showlegend=(y_label not in legend_shown)
        ))
        legend_shown.add(y_label)

# Update layout to set y-axis as category type and use the unique labels
fig.update_layout(
    yaxis=dict(
        title='Speakers',
        tickmode='array',
        tickvals=unique_labels,
        ticktext=unique_labels,
        categoryorder='array',
        categoryarray=unique_labels
    ),
    xaxis=dict(title='Time'),
    title='Diarization Visualization',
    legend=dict(title='Speakers', itemsizing='constant')
)

fig.show(renderer="browser")

In [None]:
# Extract the audio of selected speakers on one mic: everything outside their
# turns is zeroed, and the result is written as a multi-channel WAV (one channel
# per separated source).
import numpy as np
import soundfile as sf

mic_index = 1
speaker_indexes = [5, 1]
# zero-padded to two digits to match pyannote's SPEAKER_00, SPEAKER_01, ...
# (f"SPEAKER_0{i}" would yield "SPEAKER_010" for index 10)
speaker_indexes = [f"SPEAKER_{k:02d}" for k in speaker_indexes]

sr = 16000  # sample rate in Hz

mic = data[mic_index]

print(mic.shape)

# (30, 6, 160000) -> (6, 160000 * 30): stitch the chunks of each channel together
mic = einops.rearrange(mic, "a b c -> b (a c)")
num_channels = mic.shape[0]

print(mic.shape)

# (6, 4800000) -> (4800000 * 6,): channels back to back, as fed to the pipeline
mic_full_flat = einops.rearrange(mic, "a b -> (a b)")

print(mic_full_flat.shape)  # (28800000,)

speaker1 = np.zeros_like(mic_full_flat)
for turn, _, speaker in all_diars[mic_index].itertracks(yield_label=True):
    if speaker in speaker_indexes:
        start = int(turn.start * sr)
        end = int(turn.end * sr)
        speaker1[start:end] = mic_full_flat[start:end]

print(speaker1.shape)

# (28800000,) -> (6, 4800000): undo the channel concatenation
res = speaker1.reshape(num_channels, -1)

print(res.shape)

# soundfile expects (frames, channels)
sf.write("outputttt3_mic55_flaten_speaker_x.wav", res.T, sr)

# TODO: saving is not strictly needed -- the pipeline also accepts in-memory audio:
# audio_in_memory = {"waveform": waveform, "sample_rate": sample_rate}
# with waveform a (channel, time) float32 torch.Tensor, e.g. torch.Size([1, 480000])



