In [1]:
import os
import json
from copy import copy
import numpy
import itertools

In [2]:
folder = 'exp/json'
filenames = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]
data = {}

def is_valid_segment(line):
    segment = json.loads(line)
    return len(segment['speakers']) == 1 \
            and segment['speakers'][0]['speaker_id'] in ['A', 'B'] \
            and len(segment['ivectors']) == 1 \
            and len(segment['xvectors']) == 1

for filename in filenames:
    recording_id = filename.split('.')[0]
    filepath = os.path.join(folder, filename)
    file = open(filepath)
    # Removing multiple-speaker, extra-speaker and multiple-vector segments
    indexes = [index for index, line in enumerate(file.readlines()) if is_valid_segment(line)]
    file.close()
    data[recording_id] = {}
    data[recording_id]['filepath'] = filepath
    data[recording_id]['indexes'] = indexes

In [3]:
oracle_lengths = [1, 2, 3, 4, 5]
models_container_size = 3
data_length = 0
data_indexes = []
delete_recordings_ids = []
for recording_id in data:
    file = open(data[recording_id]['filepath'])
    segments = [json.loads(line) for index, line in enumerate(file.readlines()) if index in data[recording_id]['indexes']]
    speakers_ids = list(set([segment['speakers'][0]['speaker_id'] for segment in segments]))
    if len(speakers_ids) > 1:
        data[recording_id]['oracle'] = {}
        for speaker_id in speakers_ids:
            speaker_ivectors, speaker_xvectors = zip(*[(segment['ivectors'][0]['value'], segment['xvectors'][0]['value']) for segment in segments if segment['speakers'][0]['speaker_id'] == speaker_id])
            data[recording_id]['oracle'][speaker_id] = {}
            for oracle_length in oracle_lengths:
                oracle_ivector_segments = numpy.array(speaker_ivectors[:oracle_length])
                oracle_ivector = oracle_ivector_segments.sum(axis = 0) / len(oracle_ivector_segments)
                oracle_xvector_segments = numpy.array(speaker_xvectors[:oracle_length])
                oracle_xvector = oracle_xvector_segments.sum(axis = 0) / len(oracle_xvector_segments)
                data[recording_id]['oracle'][speaker_id][oracle_length] = {}
                data[recording_id]['oracle'][speaker_id][oracle_length]['ivectors'] = [oracle_ivector]
                data[recording_id]['oracle'][speaker_id][oracle_length]['xvectors'] = [oracle_xvector]
        for speaker_id in speakers_ids:
            data[recording_id]['oracle']['0'] = {}
            for oracle_length in oracle_lengths:
                data[recording_id]['oracle']['0'][oracle_length] = {}
                data[recording_id]['oracle']['0'][oracle_length]['ivectors'] = [numpy.zeros(len(data[recording_id]['oracle'][speaker_id][oracle_length]['ivectors'][0]))]
                data[recording_id]['oracle']['0'][oracle_length]['xvectors'] = [numpy.zeros(len(data[recording_id]['oracle'][speaker_id][oracle_length]['xvectors'][0]))]
            break
        permutations = list(itertools.permutations(speakers_ids + ['0' for i in range(models_container_size)], models_container_size))
        permutations = list(set(permutations))
        data[recording_id]['permutations'] = permutations
        data[recording_id]['oracle_lengths'] = oracle_lengths
        length = len(segments) * len(oracle_lengths) * len(permutations)
        data_indexes.append((data_length, data_length + length - 1, recording_id))
        data_length += length
    else:
        delete_recordings_ids.append(recording_id)
for recording_id in delete_recordings_ids:
    del data[recording_id]

In [13]:
import torch
from torch.utils.data import Dataset, DataLoader

class SegmentsDataset(Dataset):
    def __init__(self, recordings_data, recordings_data_length, recordings_data_indexes, mode = 'ivectors'):
        self.recordings_data = recordings_data
        self.recordings_data_length = recordings_data_length
        self.recordings_data_indexes = recordings_data_indexes
        self.mode = mode
    def __len__(self):
        return self.recordings_data_length
    def __getitem__(self, idx):
        data_index = list(filter(lambda data_index: data_index[0] <= idx and idx <= data_index[1], self.recordings_data_indexes))[0]
        recording_id = data_index[2]
        recording_data = self.recordings_data[recording_id]
        recording_index = idx - data_index[0] # Index relative to the recording data
        segment_index, remainder = divmod(recording_index, (len(recording_data['oracle'][next(iter(recording_data['oracle']))]) * len(recording_data['permutations'])))
        segment_index = recording_data['indexes'][segment_index]
        oracle_index, permutation_index = divmod(remainder, len(recording_data['permutations']))
        segment = json.loads(open(recording_data['filepath']).readlines()[segment_index])
        vector = numpy.array(segment[self.mode][0]['value'])
        permutation = recording_data['permutations'][permutation_index]
        permutation = [recording_data['oracle'][speaker_id][recording_data['oracle_lengths'][oracle_index]][self.mode][0] for speaker_id in permutation]
        x = numpy.concatenate([vector] + permutation)
        y = numpy.asarray([speaker_id == segment['speakers'][0]['speaker_id'] for speaker_id in recording_data['permutations'][permutation_index]], dtype = float)
        return x, y
    
segmentsDataset = SegmentsDataset(data, data_length, data_indexes)
print(segmentsDataset[97999])

(array([-0.9675461 ,  1.531406  , -0.8428259 ,  0.00920541, -0.3674863 ,
       -0.4118845 ,  1.971635  ,  1.793968  , -2.073455  , -1.070194  ,
        1.160505  ,  0.1885651 , -0.9266022 ,  0.6533173 ,  0.9353606 ,
        1.670867  , -1.939562  , -1.689263  , -0.9686074 ,  1.135836  ,
        0.9751277 ,  0.688001  , -0.05608142, -0.8614132 , -2.14238   ,
       -0.8599296 ,  0.7985176 , -0.7097315 ,  0.1609198 ,  0.4908713 ,
        1.347724  , -0.2451703 ,  2.179703  , -1.384889  , -0.6609148 ,
       -1.750504  , -1.741976  , -2.464054  , -1.555332  ,  0.7351388 ,
       -0.3795526 ,  1.589345  ,  1.393199  ,  0.3918317 ,  1.497644  ,
        1.25397   ,  1.550706  , -1.285792  ,  0.1855451 , -0.6066136 ,
        1.646451  ,  1.371312  ,  1.017648  ,  0.212216  ,  0.0256269 ,
       -0.7001354 , -0.8706113 ,  1.134856  ,  0.1450322 , -0.7442414 ,
        1.423244  , -0.7569695 ,  0.4481492 , -0.1227947 , -0.2150244 ,
       -0.3418096 ,  1.572001  ,  1.688847  , -1.426033  , -0.7

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(128 * 3, 384)
        self.fc1 = nn.Linear(384, 384)
        self.fc2 = nn.Linear(384, 384)
        self.fc3 = nn.Linear(384, 192)
        self.fc4 = nn.Linear(192, 192)
        self.fc5 = nn.Linear(192, 192)
        self.fc6 = nn.Linear(192, 96)
        self.fc7 = nn.Linear(96, 2)
    def forward(self, x):
        x = F.relu(self.fc0(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = self.fc7(x)
        x = F.log_softmax(x, dim = 1)
        return x

In [None]:
import torch.optim as optim

net = Net().cuda()

optimizer = optim.Adam(net.parameters(), lr = 0.5)

def custom_loss(input, target):
    return ((input - target) ** 2).sum() / input.data.nelement()

epochs = 10
for epoch in range(epochs):
    for entry in data:
        x = torch.FloatTensor(entry.x).cuda()
        y = torch.FloatTensor(entry.y).cuda()
        net.zero_grad()
        output = net(x.view(-1, 128 * 3))
        loss = custom_loss(output, y)
        loss.backward()
        optimizer.step()
    print(loss)