In [1]:
from gss.distribution import CACGMMTrainer
import numpy as np
from gss.utils import stack_parameters
import torchaudio

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Get Guidance from Transcript
from gss import mapping
from gss.utils import ArrayIntervall

def get_activity(
        iterator,
        *,
        perspective,
        garbage_class,
        dtype=np.bool,
        non_sil_alignment_fn=None,
        debug=False,
        use_ArrayIntervall=False,
):

    dict_it_S = iterator.groupby(lambda ex: ex['session_id'])

    # Dispatcher is a dict with better KeyErrors
    all_acitivity = Dispatcher()
    for session_id, it_S in dict_it_S.items():

        if perspective == 'worn':
            perspective_tmp = session_to_speakers[session_id]
        elif perspective == 'global_worn':
            perspective_tmp = ['P']  # Always from target speaker
        elif perspective == 'array':
            # The mapping considers missing arrays
            perspective_tmp = session_to_arrays[session_id]
        else:
            perspective_tmp = perspective

            if not isinstance(perspective_tmp, (tuple, list)):
                perspective_tmp = [perspective_tmp, ]

        speaker_ids = mapping.session_to_speakers[session_id]

        if use_ArrayIntervall:
            assert dtype == np.bool, dtype
            zeros = ArrayIntervall

            def ones(shape):
                arr = zeros(shape=shape)
                arr[:] = 1
                return arr
        else:
            import functools
            zeros = functools.partial(np.zeros, dtype=dtype)
            ones = functools.partial(np.ones, dtype=dtype)

        all_acitivity[session_id] = Dispatcher({
            p: Dispatcher({
                s: zeros(shape=[session_array_to_num_samples[f'{session_id}_{p}']])
                # s: ArrayIntervall(shape=[num_samples])
                for s in speaker_ids
            })
            for p in perspective_tmp
        })

        if garbage_class is True:
            for p in perspective_tmp:
                num_samples = session_array_to_num_samples[
                    f'{session_id}_{p}']
                all_acitivity[session_id][p]['Noise'] = ones(
                    shape=[num_samples],
                )
        elif garbage_class is False:
            for p in perspective_tmp:
                num_samples = session_array_to_num_samples[
                    f'{session_id}_{p}']
                all_acitivity[session_id][p]['Noise'] = zeros(
                    shape=[num_samples]
                )
        elif garbage_class is None:
            pass
        elif isinstance(garbage_class, int) and garbage_class > 0:
            for noise_idx in range(garbage_class):
                for p in perspective_tmp:
                    num_samples = session_array_to_num_samples[
                        f'{session_id}_{p}'
                    ]
                    all_acitivity[session_id][p][f'Noise{noise_idx}'] = ones(
                        shape=[num_samples]
                    )
        else:
            raise ValueError(garbage_class)

        missing_count = 0
        for ex in it_S:
            for pers in perspective_tmp:
                if ex['transcription'] == '[redacted]':
                    continue

                target_speaker = ex['speaker_id']
                # example_id = ex['example_id']

                if pers == 'P':
                    perspective_mic_array = target_speaker
                else:
                    perspective_mic_array = pers

                if perspective_mic_array.startswith('P'):
                    start = ex['start']['worn'][perspective_mic_array]
                    end = ex['end']['worn'][perspective_mic_array]
                else:
                    if not perspective_mic_array in ex['audio_path']['observation']:
                        continue
                    start = ex['start']['observation'][perspective_mic_array]
                    end = ex['end']['observation'][perspective_mic_array]

                if non_sil_alignment_fn is None:
                    value = 1
                else:
                    value = non_sil_alignment_fn(ex, perspective_mic_array)
                    if value is 1:
                        missing_count += 1

                if debug:
                    all_acitivity[session_id][pers][target_speaker][start:end] += value
                else:
                    all_acitivity[session_id][pers][target_speaker][start:end] = value
        if missing_count > len(it_S) // 2:
            raise RuntimeError(
                f'Something went wrong.\n'
                f'Expected {len(it_S) * len(perspective_tmp)} times a '
                f'finetuned annotation for session {session_id}, but '
                f'{missing_count} times they are missing.\n'
                f'Expect that at least {len(it_S) // 2} finetuned annotations '
                f'are available, when non_sil_alignment_fn is given.\n'
                f'Otherwise assume something went wrong.'
            )

        del it_S

    return all_acitivity


  if value is 1:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  dtype=np.bool,


In [5]:
import torch
import torchaudio
import numpy as np
from nara_wpe.wpe import wpe
import os
from tqdm import tqdm

In [3]:
class GSS:
    iterations: int
    iterations_post: int

    verbose: bool = True

    # use_pinv: bool = False
    # stable: bool = True

    def __call__(self, Obs, acitivity_freq, debug=False):

        initialization = np.asarray(acitivity_freq, dtype=np.float64)
        initialization = np.where(initialization == 0, 1e-10, initialization)
        initialization = initialization / np.sum(initialization, keepdims=True,
                                                axis=0)
        initialization = np.repeat(initialization[None, ...], 513, axis=0)

        source_active_mask = np.asarray(acitivity_freq, dtype=np.bool)
        source_active_mask = np.repeat(source_active_mask[None, ...], 513, axis=0)

        cacGMM = CACGMMTrainer()

        if debug:
            learned = []
        all_affiliations = []
        F = Obs.shape[-1]
        T = Obs.T.shape[-2]
        for f in range(F):
            if self.verbose:
                if f % 50 == 0:
                    print(f'{f}/{F}')

            # T: Consider end of signal.
            # This should not be nessesary, but activity is for inear and not for
            # array.
            cur = cacGMM.fit(
                y=Obs.T[f, ...],
                initialization=initialization[f, ..., :T],
                iterations=self.iterations,
                source_activity_mask=source_active_mask[f, ..., :T],
                # return_affiliation=True,
            )

            if self.iterations_post != 0:
                if self.iterations_post != 1:
                    cur = cacGMM.fit(
                        y=Obs.T[f, ...],
                        initialization=cur,
                        iterations=self.iterations_post - 1,
                    )
                affiliation = cur.predict(
                    Obs.T[f, ...],
                )
            else:
               affiliation = cur.predict(
                   Obs.T[f, ...],
                   source_activity_mask=source_active_mask[f, ..., :T]
               )

            if debug:
                learned.append(cur)
            all_affiliations.append(affiliation)

        posterior = np.array(all_affiliations).transpose(1, 2, 0)

        if debug:
            learned = stack_parameters(learned)
            self.locals = locals()

        return posterior

In [4]:
datapath = "/Users/danilfedorovsky/Documents/10 Collection/00 Studium/00 Letztes Semester/Masterarbeit/Code/Git Repo/data/1WPE/"

data, sample_rate = torchaudio.load(datapath+"WPE_S02_P05.wav")
data = data.detach().numpy()

In [5]:
obs = np.array(10)
gss = GSS().__call__(obs,data)

  initialization = initialization / np.sum(initialization, keepdims=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  source_active_mask = np.asarray(acitivity_freq, dtype=np.bool)


IndexError: tuple index out of range