In [1]:
import bisect
import json
import os

# is_valid_segment [DONE]
def is_valid_segment(segment, valid_speakers_length = 2, valid_speakers_ids = ('A', 'B')):
    """Return True when a segment's speakers are few enough and all allowed.

    segment: dict whose 'speakers' entry is a list of dicts with a 'speaker_id' key.
    valid_speakers_length: maximum number of DISTINCT speaker ids allowed.
    valid_speakers_ids: container of allowed speaker ids. The default is a tuple
        rather than a list so no mutable default object is shared between calls.
    """
    speakers_ids = {speaker['speaker_id'] for speaker in segment['speakers']}
    return len(speakers_ids) <= valid_speakers_length and \
        all(speaker_id in valid_speakers_ids for speaker_id in speakers_ids)

# load_recordings_segments [DONE]
def load_recordings_segments(directory, segment_filter = None):
    """Load every regular file in `directory` as one recording's list of segments.

    Each file is read as JSON lines: every non-blank line is parsed with json.loads
    into one segment. The recording id is the filename up to its first '.'.
    Segments are then filtered, and a progress line plus the fraction of segments
    kept is printed.

    directory: path containing one JSON-lines file per recording.
    segment_filter: optional predicate segment -> bool; None means is_valid_segment.
    Returns a dict recording_id -> list of kept segments, built in sorted filename order.
    """
    if segment_filter is None:
        segment_filter = is_valid_segment
    filenames = sorted(filename for filename in os.listdir(directory)
                       if os.path.isfile(os.path.join(directory, filename)))
    recordings_segments = {}
    recordings_length = len(filenames)
    recordings_count = 0
    segments_original = 0
    segments_filtered = 0
    for recordings_count, filename in enumerate(filenames, 1):
        recording_id = filename.split('.')[0]
        # `with` closes the file even when a line fails to parse.
        with open(os.path.join(directory, filename), 'r', encoding = 'utf-8') as file:
            # Skip blank lines (e.g. a doubled trailing newline), which json.loads rejects.
            segments = [json.loads(line) for line in file if line.strip()]
        segments_original += len(segments)
        recordings_segments[recording_id] = [segment for segment in segments if segment_filter(segment)]
        segments_filtered += len(recordings_segments[recording_id])
        print(directory + ' loading ' + str(recordings_count) + '/' + str(recordings_length), end = '\r')
    # An empty directory (or only empty files) would otherwise divide by zero.
    segments_ratio = round(segments_filtered / segments_original, 2) if segments_original else 0.0
    print(directory, 'loaded', str(recordings_count) + '/' + str(recordings_length) + ',', segments_ratio, 'segments left.')
    return recordings_segments

In [2]:
# Load the CallHome-1 per-recording segments from their JSON-lines directory.
callhome1_recordings_segments = load_recordings_segments(
    directory = '../exp/pre_norm/callhome1/json')

../exp/pre_norm/callhome1/json loaded 249/249, 0.74 segments left.


In [3]:
# Load the CallHome-2 per-recording segments from their JSON-lines directory.
callhome2_recordings_segments = load_recordings_segments(
    directory = '../exp/pre_norm/callhome2/json')

../exp/pre_norm/callhome2/json loaded 250/250, 0.77 segments left.


In [4]:
from torch.utils.data import Dataset
from functools import reduce
import random
import numpy as np
import itertools

# speakers_get_indexes [DONE]
def speakers_get_indexes(accumulator, speakers_tuple):
    """functools.reduce step: file a segment index under its speakers key.

    speakers_tuple is (speaker_ids, index); the key is ','.join(speaker_ids).
    The accumulator dict is mutated in place and returned so reduce can chain it.
    """
    ids, position = speakers_tuple
    accumulator.setdefault(','.join(ids), []).append(position)
    return accumulator

def generate_speaker_model(recording_segments,
                           speaker_indexes,
                           segments_length,
                           vector = 'ivectors',
                           selection = 'first',
                           indexes = ()):
    """Average the `vector` values of selected segments into one speaker model.

    recording_segments: list of segment dicts; segment[vector][0]['value'] is the
        segment's vector.
    speaker_indexes: indexes into recording_segments that belong to the speaker.
    segments_length: how many of speaker_indexes to use ('first' / 'random').
    selection:
        'first'   -> the first segments_length entries of speaker_indexes, taken in
                     recording_segments order;
        'random'  -> random.sample of min(segments_length, len(speaker_indexes));
        'indexes' -> exactly the explicit `indexes` (tuple default avoids a shared
                     mutable default list).
    Returns the element-wise mean of the selected vectors (numpy array).
    Raises ValueError for an unknown selection or when no segment is selected.
    """
    if segments_length > len(speaker_indexes):
        print('WARNING: there are less speaker indexes than segments.')
    if selection == 'first':
        # Set membership keeps this linear; iterating recording_segments preserves
        # the original ordering (and hence the floating-point summation order).
        wanted = set(speaker_indexes[:segments_length])
        selected_segments = [segment for index, segment in enumerate(recording_segments) if index in wanted]
    elif selection == 'random':
        selected_segments = [recording_segments[index]
                             for index in random.sample(speaker_indexes, min(segments_length, len(speaker_indexes)))]
    elif selection == 'indexes':
        selected_segments = [recording_segments[index] for index in indexes]
    else:
        # Fail loudly instead of falling through with selected_segments unbound.
        raise ValueError('ERROR: unknown speaker model segments selection strategy.')
    if not selected_segments:
        # Averaging zero vectors would silently yield a NaN model.
        raise ValueError('no segments selected for the speaker model.')
    selected_vectors = [np.asarray(segment[vector][0]['value']) for segment in selected_segments]
    return np.sum(selected_vectors, 0) / len(selected_vectors)
    

class Recordings_dataset(Dataset):
    """Dataset pairing each segment vector with a container of speaker models.

    Per recording, segments are grouped by their sorted, comma-joined set of
    speaker ids. Groups may be balanced by appending random copies of their
    segments. One model per entry of `models_generation_lengths` is built for each
    group with generate_speaker_model. Permutations of `models_container_length`
    group keys are then enumerated; they are optionally padded with the '0'
    placeholder and, unless overlaps are included, restricted to single-speaker
    keys.

    Items of a recording are ordered segment-major, then permutation-major, then by
    each slot's model choice (last slot varying fastest).

    __getitem__ returns (x, y, z):
        x: [segment vector] + one vector per slot. A '0' slot gets uniform noise
           in [-0.1, 0.1).
        y: with overlaps, 1.0 where the slot key equals the segment's key;
           without overlaps, 1.0 / k where the slot speaker is one of the
           segment's k distinct speakers.
        z: per-slot weights 1 - count / sum(counts). count is the slot group's
           number of segment indexes; a '0' slot uses the largest group count.
    """

    def __init__(self,
                 recordings_segments,
                 recordings_ids = None,
                 vector = 'ivectors',
                 models_container_length = 2,
                 models_container_include_zeros = True,
                 models_container_include_overlaps = False,
                 models_generation_lengths = (3, 4),
                 models_generation_selection = 'first',
                 balance_segments = True,
                 balance_segments_selection = 'copy'):
        """Store the configuration and precompute per-recording indexes, models and maps.

        recordings_segments: dict recording_id -> list of segment dicts. Each list
            is copied, so balancing never appends to the caller's data.
        recordings_ids: list of ids to include, or a single id; None means all.
        Raises ValueError for an unknown balance_segments_selection when balancing.
        """
        if balance_segments and balance_segments_selection != 'copy':
            raise ValueError('unknown balancing segments selection strategy: ' + repr(balance_segments_selection))
        # -----------------------------------------------------Saving input data----- #
        if recordings_ids is None:
            recordings_ids = list(recordings_segments)
        self.recordings_ids = list(recordings_ids) if isinstance(recordings_ids, list) else [recordings_ids]
        # Shallow-copy each list: balancing appends duplicates, and mutating the
        # caller's lists would inflate them every time a dataset is rebuilt.
        self.recordings_segments = {recording_id: list(recordings_segments[recording_id])
                                    for recording_id in self.recordings_ids}
        self.vector = vector
        self.models_container_length = models_container_length
        self.models_container_include_zeros = models_container_include_zeros
        self.models_container_include_overlaps = models_container_include_overlaps
        self.models_generation_lengths = list(models_generation_lengths)
        self.models_generation_selection = models_generation_selection
        self.balance_segments = balance_segments
        self.balance_segments_selection = balance_segments_selection
        # --------------------------------------------------------------------------- #
        self.recordings_data = {}
        # Each recordings_map entry is (first item, last item, recording_id).
        self.recordings_map = []
        self.recordings_length = 0
        for recording_id in self.recordings_ids:
            recording_data = self._prepare_recording(self.recordings_segments[recording_id])
            self.recordings_data[recording_id] = recording_data
            self.recordings_map.append((self.recordings_length, self.recordings_length + recording_data['length'] - 1, recording_id))
            self.recordings_length += recording_data['length']
        # Start offsets of recordings_map, used for bisect lookups in __getitem__.
        self._recordings_starts = [limits[0] for limits in self.recordings_map]

    def _prepare_recording(self, recording_segments):
        """Build one recording's data dict; may append balancing copies to recording_segments."""
        recording_data = {}
        # ----- Obtaining speakers indexes ----- #
        speakers_tuples = [(sorted(set(speaker['speaker_id'] for speaker in segment['speakers'])), index)
                           for index, segment in enumerate(recording_segments)]
        recording_data['speakers_indexes'] = reduce(speakers_get_indexes, speakers_tuples, {})
        # default=0 keeps a recording whose segments were all filtered out from crashing.
        recording_data['speakers_indexes_lengths_max'] = max(
            (len(indexes) for indexes in recording_data['speakers_indexes'].values()), default = 0)
        # ----- Balancing speakers segments ----- #
        if self.balance_segments:
            self._balance_segments(recording_segments, recording_data['speakers_indexes'],
                                   recording_data['speakers_indexes_lengths_max'])
        # ----- Generating speakers models ----- #
        recording_data['speakers_models'] = self._generate_speakers_models(recording_segments, recording_data['speakers_indexes'])
        # ----- Generating permutations ----- #
        recording_data['permutations'] = self._generate_permutations(list(recording_data['speakers_models']))
        # ----- Mapping item offsets to permutations ----- #
        permutations_map = []
        permutations_length = 0
        for index, permutation in enumerate(recording_data['permutations']):
            speakers_models_length = int(np.prod(self._speakers_models_counts(recording_data, permutation)))
            permutations_map.append((permutations_length, permutations_length + speakers_models_length - 1, index))
            permutations_length += speakers_models_length
        recording_data['permutations_map'] = permutations_map
        recording_data['permutations_starts'] = [limits[0] for limits in permutations_map]
        recording_data['permutations_length'] = permutations_length
        recording_data['length'] = len(recording_segments) * permutations_length
        return recording_data

    @staticmethod
    def _balance_segments(recording_segments, speakers_indexes, target_length):
        """Append random copies of each group's segments until the group has target_length indexes."""
        for indexes in speakers_indexes.values():
            for _ in range(target_length - len(indexes)):
                recording_segments.append(recording_segments[random.choice(indexes)])
                indexes.append(len(recording_segments) - 1)

    def _generate_speakers_models(self, recording_segments, speakers_indexes):
        """Map each speakers key to {generation length: [model]} built by generate_speaker_model."""
        speakers_models = {}
        for speakers_ids, indexes in speakers_indexes.items():
            speakers_models[speakers_ids] = {}
            for models_generation_length in self.models_generation_lengths:
                speakers_models[speakers_ids][models_generation_length] = [generate_speaker_model(
                    recording_segments, indexes, models_generation_length, self.vector, self.models_generation_selection)]
        return speakers_models

    def _generate_permutations(self, speakers_keys):
        """Return the sorted, unique slot permutations over speakers_keys (plus '0' padding if enabled)."""
        candidates = list(speakers_keys)
        if self.models_container_include_zeros:
            candidates += ['0'] * self.models_container_length
        permutations = sorted(set(itertools.permutations(candidates, self.models_container_length)))
        if not self.models_container_include_overlaps:
            # Overlap keys are comma-joined speaker ids; keep single-speaker slots only.
            permutations = [permutation for permutation in permutations
                            if all(',' not in speakers_ids for speakers_ids in permutation)]
        return permutations

    @staticmethod
    def _speakers_models_counts(recording_data, permutation):
        """Return the number of selectable models for each slot of `permutation` (1 for '0' slots)."""
        speakers_models = recording_data['speakers_models']
        return [1 if speakers_ids == '0' else sum(len(models) for models in speakers_models[speakers_ids].values())
                for speakers_ids in permutation]

    @staticmethod
    def _decode_model_indexes(position, radices):
        """Split `position` into one digit per slot (mixed radix, last slot fastest)."""
        digits = []
        for radix in reversed(radices):
            position, digit = divmod(position, radix)
            digits.append(digit)
        digits.reverse()
        return digits

    def __len__(self):
        """Return the total number of items across all recordings."""
        return self.recordings_length

    def __getitem__(self, idx):
        """Return (x, y, z) for flat item `idx`; raise IndexError when it is out of range."""
        if not 0 <= idx < self.recordings_length:
            raise IndexError('Recordings_dataset index out of range: ' + str(idx))
        # Pick the last recording starting at or before idx. bisect_right moves past
        # equal starts, so zero-length recordings are never selected.
        recording_start, _, recording_id = self.recordings_map[bisect.bisect_right(self._recordings_starts, idx) - 1]
        recording_data = self.recordings_data[recording_id]

        segment_index, segment_idx = divmod(idx - recording_start, recording_data['permutations_length'])
        segment = self.recordings_segments[recording_id][segment_index]
        vector = np.asarray(segment[self.vector][0]['value'])

        permutation_start, _, permutation_index = recording_data['permutations_map'][
            bisect.bisect_right(recording_data['permutations_starts'], segment_idx) - 1]
        permutation = recording_data['permutations'][permutation_index]
        # Proper mixed-radix decoding: correct for any container length, not only 2.
        model_indexes = self._decode_model_indexes(segment_idx - permutation_start,
                                                   self._speakers_models_counts(recording_data, permutation))

        models_container = []
        for speakers_ids, model_index in zip(permutation, model_indexes):
            if speakers_ids == '0':
                # Placeholder slot: small uniform noise instead of a speaker model.
                models_container.append(np.random.uniform(-0.1, 0.1, len(vector)))
            else:
                models_container.append(recording_data['speakers_models'][speakers_ids][self.models_generation_lengths[model_index]][0])

        speakers_indexes = recording_data['speakers_indexes']
        models_counts = np.asarray([recording_data['speakers_indexes_lengths_max'] if speakers_ids == '0'
                                    else len(speakers_indexes[speakers_ids]) for speakers_ids in permutation])
        models_weights = np.ones(len(models_counts)) - models_counts / np.sum(models_counts)

        # Deduplicate speaker ids, as the grouping in _prepare_recording does, so a
        # speaker listed twice in a segment does not halve its target value.
        targets_ids = sorted(set(speaker['speaker_id'] for speaker in segment['speakers']))
        x = [vector] + models_container
        if self.models_container_include_overlaps:
            targets_key = ','.join(targets_ids)
            y = np.asarray([speakers_ids == targets_key for speakers_ids in permutation], dtype = float)
        else:
            y = np.asarray([speakers_ids in targets_ids for speakers_ids in permutation], dtype = float) / len(targets_ids)
        return x, y, models_weights

In [5]:
# Build the CallHome-1 dataset with default settings, then show its size and one item.
callhome1_dataset = Recordings_dataset(recordings_segments = callhome1_recordings_segments)
callhome1_dataset_length = len(callhome1_dataset)
print(callhome1_dataset_length)
print(callhome1_dataset[37])

1150356
([array([ 0.3519602 ,  4.005612  , -0.9051194 ,  0.427141  , -0.5701674 ,
        1.003131  ,  2.235364  ,  1.921338  , -3.015491  , -1.534834  ,
        1.217496  ,  0.710883  , -0.4937762 , -1.104113  , -0.9110618 ,
       -0.167288  ,  1.414095  , -0.1112232 , -0.6965439 ,  0.6867012 ,
       -0.7766348 , -0.04425273,  0.4530951 , -1.876562  , -0.04310996,
       -0.7285927 , -2.392449  , -1.250621  , -0.370301  , -0.9466311 ,
       -1.463477  , -0.7710519 ,  1.655105  , -0.7872185 ,  0.2859635 ,
        0.8824552 ,  0.1180109 , -0.1328261 ,  0.6444096 , -1.067963  ,
        0.4666332 ,  0.680375  ,  0.00625678,  1.387998  , -1.354118  ,
        0.4809265 ,  0.5807109 , -0.762727  , -0.3915525 ,  1.229105  ,
       -1.247902  , -0.3112767 , -0.1510374 ,  1.076944  , -1.710107  ,
       -0.2918752 , -0.5748917 ,  1.470002  ,  0.5160106 , -0.925755  ,
       -1.05996   , -0.7791965 ,  0.4565023 , -0.0949901 ,  0.04555997,
       -1.829623  ,  1.14681   , -0.4794345 , -1.45516