In [1]:
import os
import json

def is_valid_segment(segment):
    return len(segment['speakers']) == 1 \
            and len(segment['ivectors']) == 1 \
            and len(segment['xvectors']) == 1 \
            and segment['speakers'][0]['speaker_id'] in ['A', 'B']

directory = '../exp/json'
filenames = [filename for filename in os.listdir(directory) if os.path.isfile(os.path.join(directory, filename))]

recordings_segments = {}
recordings_length = len(filenames)
recordings_count = 0
for filename in filenames:
    recording_id = filename.split('.')[0]
    filepath = os.path.join(directory, filename)
    file = open(filepath, 'r')
    recordings_segments[recording_id] = [json.loads(line) for line in file.readlines()]
    file.close()
    recordings_segments[recording_id] = list(filter(is_valid_segment, recordings_segments[recording_id]))
    recordings_count += 1
    print('Loading data: ' + str(recordings_count) + '/' + str(recordings_length), end = '\r')

Loading data: 249/249

## Balancing the dataset

In [2]:
from functools import reduce

def speakers_get_indexes(accumulator, speaker_tuple):
    speaker_id, index = speaker_tuple
    if speaker_id in accumulator:
        accumulator[speaker_id].append(index)
    else:
        accumulator[speaker_id] = [index]
    return accumulator

recordings_segments_cut = {}
for recording_id in recordings_segments:
    recording_segments = recordings_segments[recording_id]
    speakers_indexes = [(segment['speakers'][0]['speaker_id'], index) for index, segment in enumerate(recording_segments)]
    speakers_indexes = reduce(speakers_get_indexes, speakers_indexes, {})
    speakers_lengths = [(speaker_id, len(speakers_indexes[speaker_id])) for speaker_id in speakers_indexes]
    speakers_lengths.sort(key = lambda x: x[1])
    speakers_lengths_min = speakers_lengths[0][1]
    if len(speakers_lengths) > 1 and speakers_lengths_min >= 20: # <-- IMPORTANT
        recording_indexes = []
        for speaker_id in speakers_indexes:
            speakers_indexes[speaker_id] = speakers_indexes[speaker_id][:speakers_lengths_min]
            recording_indexes += speakers_indexes[speaker_id]
        recordings_segments_cut[recording_id] = [segment for index, segment in enumerate(recordings_segments[recording_id]) if index in recording_indexes]
print('Recordings left:', len(recordings_segments_cut))

Recordings left: 172


## Dataset class inheritance

In [3]:
from torch.utils.data import Dataset
import numpy as np
import itertools

models_generation_length = 3 # <-- IMPORTANT
models_container_length = 2  # <-- IMPORTANT
permutations_include_zeros = False

class Recordings_dataset(Dataset):
    def __init__(self, recordings_segments, recordings_ids, mode = 'ivectors'):
        self.recordings_ids = recordings_ids if isinstance(recordings_ids, list) else [recordings_ids]
        self.recordings_segments = {}
        for recording_id in self.recordings_ids:
            self.recordings_segments[recording_id] = recordings_segments[recording_id]
        self.mode = mode
        self.models_generation_length = models_generation_length
        self.models_container_length = models_container_length
        self.permutations_include_zeros = permutations_include_zeros
        self.recordings_data = {}
        self.recordings_map = []
        self.recordings_length = 0
        for recording_id in self.recordings_ids:
            self.recordings_data[recording_id] = {}
            recording_segments = self.recordings_segments[recording_id]
            recording_data = self.recordings_data[recording_id]
            recording_data['speakers_indexes'] = [(segment['speakers'][0]['speaker_id'], index) for index, segment in enumerate(recording_segments)]
            recording_data['speakers_indexes'] = reduce(speakers_get_indexes, recording_data['speakers_indexes'], {})
            recording_data['speakers_indexes_lengths_max'] = max([len(recording_data['speakers_indexes'][speaker_id]) for speaker_id in recording_data['speakers_indexes']])
            recording_data['speakers_models'] = {}
            for speaker_id in recording_data['speakers_indexes']:
                speaker_indexes = recording_data['speakers_indexes'][speaker_id]
                speaker_vectors = [np.asarray(recording_segments[index][self.mode][0]['value']) for index in speaker_indexes[:self.models_generation_length]]
                recording_data['speakers_models'][speaker_id] = [np.sum(speaker_vectors, 0) / len(speaker_vectors)]
            if self.permutations_include_zeros:
                recording_data['permutations'] = list(itertools.permutations(list(recording_data['speakers_models'].keys()) \
                + ['0' for i in range(self.models_container_length)], self.models_container_length))
            else:
                recording_data['permutations'] = list(itertools.permutations(list(recording_data['speakers_models'].keys()), self.models_container_length))
            recording_data['permutations'] = list(set(recording_data['permutations']))
            recording_data['permutations'].sort()
            recording_data['permutations_map'] = []
            recording_data['permutations_length'] = 0
            for index, permutation in enumerate(recording_data['permutations']):
                speakers_models_length = int(np.prod([len(recording_data['speakers_models'][speaker_id]) for speaker_id in permutation if speaker_id != '0']))
                recording_data['permutations_map'].append((recording_data['permutations_length'], recording_data['permutations_length'] + speakers_models_length - 1, index))
                recording_data['permutations_length'] += speakers_models_length
            recording_data['length'] = len(recording_segments) * recording_data['permutations_length']
            self.recordings_map.append((self.recordings_length, self.recordings_length + recording_data['length'] - 1, recording_id))
            self.recordings_length += recording_data['length']
    def __len__(self):
        return self.recordings_length
    def __getitem__(self, idx):
        recording_tuple = list(filter(lambda recording_tuple: recording_tuple[0] <= idx and idx <= recording_tuple[1], self.recordings_map))[0]
        recording_idx = idx - recording_tuple[0]
        recording_id = recording_tuple[2]
        recording_data = self.recordings_data[recording_id]
        
        segment_id, segment_idx = divmod(recording_idx, recording_data['permutations_length'])
        segment = self.recordings_segments[recording_id][segment_id]
        target_id = segment['speakers'][0]['speaker_id']
        vector = np.asarray(segment[self.mode][0]['value'])
        
        permutation_tuple = list(filter(lambda permutation_tuple: permutation_tuple[0] <= segment_idx and segment_idx <= permutation_tuple[1], recording_data['permutations_map']))[0]
        permutation_id = permutation_tuple[2]
        permutation = recording_data['permutations'][permutation_id]
        
        models_container = [np.asarray(recording_data['speakers_models'][speaker_id][0]) if speaker_id != '0' else np.zeros(len(vector)) for speaker_id in permutation]
        models_weigths = np.asarray([len(recording_data['speakers_indexes'][speaker_id]) if speaker_id != '0' else recording_data['speakers_indexes_lengths_max'] for speaker_id in permutation])
        models_weigths_sum = np.sum(models_weigths)
        models_weigths = np.ones(len(models_weigths)) - models_weigths / models_weigths_sum
        
        x = [vector] + models_container
        y = np.asarray([speaker_id == target_id for speaker_id in permutation], dtype = float)
        z = models_weigths
        
        return x, y, z

## DNN model

In [4]:
import torch.nn as nn
import torch.nn.functional as F

n = models_container_length

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv1d((n + 1), (n + 1), 3, padding = 1),
            nn.ReLU(),
            nn.Dropout(p = 0.2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear((n + 1) * 128, 16),
            nn.ReLU(),
            nn.Dropout(p = 0.5),
            nn.Linear(16, n),
            nn.Sigmoid(),
        )
        
    def forward(self, input):
        x = torch.stack(input, 1)
        x = self.cnn1(x)
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        return x

In [5]:
from torch.utils.data import DataLoader, random_split

recordings_ids = [recording_id for recording_id in recordings_segments_cut]
recordings_ids_cut = int(len(recordings_ids) * 0.7)
print(recordings_ids_cut)

train_dataset = Recordings_dataset(recordings_segments_cut, recordings_ids[:recordings_ids_cut])
test_dataset = Recordings_dataset(recordings_segments_cut, recordings_ids[recordings_ids_cut:])

train_length = int(len(train_dataset) * 0.7)
validation_length = len(train_dataset) - train_length

train_dataset, validation_dataset = random_split(train_dataset, [train_length, validation_length])
train_dataloader = DataLoader(train_dataset, batch_size = 10, shuffle=True, num_workers = 8)
validation_dataloader = DataLoader(validation_dataset, batch_size = 10, num_workers = 8)
test_dataloader = DataLoader(test_dataset, batch_size = 10, num_workers = 8)

120


In [6]:
import torch
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print('Running on the GPU.')
else:
    device = torch.device('cpu')
    print('Running on the CPU.')

Running on the GPU.


In [7]:
%matplotlib notebook
import matplotlib.pyplot as plt

class Live_graph:
    def __init__(self, validation_threshold):
        self.plt_count = -1
        self.validation_threshold = validation_threshold
        self.plt_thr = ([self.plt_count], [self.validation_threshold])
        self.plt_loss = ([self.plt_count], [1])
        self.plt_valid = ([self.plt_count], [1])
        self.plt_test = ([self.plt_count], [1])
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot()
        self.line0, = self.ax.plot(self.plt_thr[0], self.plt_thr[1], 'k--', label = 'Threshold') # Threshold line
        self.line1, = self.ax.plot(self.plt_loss[0], self.plt_loss[1], '--', label = 'Training') # Training loss
        self.line2, = self.ax.plot(self.plt_valid[0], self.plt_valid[1], label = 'Validation')   # Validation loss
        self.line3, = self.ax.plot(self.plt_test[0], self.plt_test[1], label = 'Test')           # Test loss
        self.ax.set_xlabel('Epoch')
        self.ax.set_ylabel('Loss')
        self.ax.legend()
        self.ax.set_xlim(-1, 0)
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
    def step(self, training, validation, test):
        self.plt_count += 1
        self.plt_thr[0].append(self.plt_count)
        self.plt_thr[1].append(self.validation_threshold)
        self.plt_loss[0].append(self.plt_count)
        self.plt_loss[1].append(training)
        self.plt_valid[0].append(self.plt_count)
        self.plt_valid[1].append(validation)
        self.plt_test[0].append(self.plt_count)
        self.plt_test[1].append(test)
        self.line0.set_xdata(self.plt_thr[0])
        self.line0.set_ydata(self.plt_thr[1])
        self.line1.set_xdata(self.plt_loss[0])
        self.line1.set_ydata(self.plt_loss[1])
        self.line2.set_xdata(self.plt_valid[0])
        self.line2.set_ydata(self.plt_valid[1])
        self.line3.set_xdata(self.plt_test[0])
        self.line3.set_ydata(self.plt_test[1])
        self.ax.set_xlim(0, self.plt_count + 1)
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

In [8]:
import torch.optim as optim

net = Net().to(device)
optimizer = optim.Adam(net.parameters(), lr = 0.0004) # 0.0004 GOOD

epochs = 10
validation_threshold = 0.1

live_graph = Live_graph(validation_threshold)

for epoch in range(epochs):
    train_losses = []
    for input, target, weigth in train_dataloader:
        input = [tensor.to(device, non_blocking = True).float() for tensor in input]
        target = target.to(device, non_blocking = True).float()
        weigth = weigth.to(device, non_blocking = True).float()
        
        criterion = nn.BCELoss(weigth)
        net.zero_grad()
        output = net(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.data)
        print('train: ' + str(len(train_losses)) + '/' + str(len(train_dataloader)) + '     ', end = '\r')
    train_loss = np.sum(train_losses) / len(train_losses)
        
        
    
    with torch.no_grad():
        validation_losses = []
        for input, target, weigth in validation_dataloader:
            input = [tensor.to(device, non_blocking = True).float() for tensor in input]
            target = target.to(device, non_blocking = True).float()
            weigth = weigth.to(device, non_blocking = True).float()
            
            criterion = nn.BCELoss(weigth)
            output = net(input)
            loss = criterion(output, target)
            validation_losses.append(loss.data)
            print('validation: ' + str(len(validation_losses)) + '/' + str(len(validation_dataloader)) + '     ', end = '\r')
        validation_loss = np.sum(validation_losses) / len(validation_losses)
    
        test_losses = []
        for input, target, weigth in test_dataloader:
            input = [tensor.to(device, non_blocking = True).float() for tensor in input]
            target = target.to(device, non_blocking = True).float()
            weigth = weigth.to(device, non_blocking = True).float()
            
            criterion = nn.BCELoss(weigth)
            output = net(input)
            loss = criterion(output, target)
            test_losses.append(loss.data)
            print('test: ' + str(len(test_losses)) + '/' + str(len(test_dataloader)) + '     ', end = '\r')
        test_loss = np.sum(test_losses) / len(test_losses)
            
    live_graph.step(train_loss, validation_loss, test_loss)
    
    if validation_loss <= validation_threshold:
        print('Done training.')
        break

<IPython.core.display.Javascript object>

Done training.    8     


In [9]:
test_dataloader = DataLoader(test_dataset, batch_size = 1, num_workers = 8)
test_count = 0
correct = 0
with torch.no_grad():
    for input, target, weigth in test_dataloader:
        input = [tensor.to(device, non_blocking = True).float() for tensor in input]
        target = target.to(device, non_blocking = True).float()
        weigth = weigth.to(device, non_blocking = True).float()
        output = net(input)
        if target.max(1)[1] == output.max(1)[1]:
            correct += 1
        test_count += 1
        print('test: ' + str(test_count) + '/' + str(len(test_dataloader)) + ' accuracy: ' + str(correct / test_count), end = '\r')

test: 8324/8324 accuracy: 0.9180682364247957

In [10]:
results = ''

test_recordings_ids = recordings_ids[recordings_ids_cut:]
test_recordings_ids.sort()

for recording_id in test_recordings_ids:
    recording_dataset = Recordings_dataset(recordings_segments_cut, recording_id)
    speakers_models = recording_dataset.recordings_data[recording_id]['speakers_models']
    # At this point there is no information about the speaker identity, only the model
    speakers_models = [speakers_models[speaker_id][0] for speaker_id in speakers_models]
    last_speaker_index = -1
    last_speaker = { 'begining': 0, 'ending': 0, 'index': -1 }
    for segment in recordings_segments[recording_id]:
        begining = segment['begining']
        ending = segment['ending']
        vector = np.asarray(segment['ivectors'][0]['value'])
        with torch.no_grad():
            input = [torch.Tensor([nparray]).to(device, non_blocking = True).float() for nparray in [vector] + speakers_models]
            output = net(input)
            index = output.max(1)[1].cpu().data.numpy()[0]
            if last_speaker_index != index:
                if last_speaker_index != -1:
                    result = 'SPEAKER ' + recording_id + ' 0 ' + str(last_speaker['begining']) + ' ' + str(round(last_speaker['ending'] - last_speaker['begining'], 2)) + ' <NA> <NA> ' + str(last_speaker['index']) + ' <NA> <NA>'
                    results += result + '\n'
                last_speaker_index = index
                last_speaker = { 'begining': begining, 'ending': ending, 'index': index }
            else:
                if begining <= last_speaker['ending']:
                    last_speaker['ending'] = ending
                else:
                    if last_speaker_index != -1:
                        result = 'SPEAKER ' + recording_id + ' 0 ' + str(last_speaker['begining']) + ' ' + str(round(last_speaker['ending'] - last_speaker['begining'], 2)) + ' <NA> <NA> ' + str(last_speaker['index']) + ' <NA> <NA>'
                        results += result + '\n'
                    last_speaker_index = index
                    last_speaker = { 'begining': begining, 'ending': ending, 'index': index }
                        

groundtruth_rttm_filepath = '../callhome1_1.0_0.5.rttm'
file = open(groundtruth_rttm_filepath, 'r')
test_rttm = ''.join([line for line in file.readlines() if (line.split(' ')[1] in test_recordings_ids) and \
                    (line.split(' ')[7] in ['A', 'B'])])
file.close()

file = open('test_results.rttm', 'w')
file.write(results)
file.close()

file = open('test_groundtruth.rttm', 'w')
file.write(test_rttm)
file.close()

*** Performance analysis for Speaker Diarization for ALL ***

    EVAL TIME =   7502.91 secs
  EVAL SPEECH =   3807.20 secs ( 50.7 percent of evaluated time)
  SCORED TIME =   7502.91 secs (100.0 percent of evaluated time)
SCORED SPEECH =   3807.20 secs ( 50.7 percent of scored time)
   EVAL WORDS =      0        
 SCORED WORDS =      0         (100.0 percent of evaluated words)
---------------------------------------------
MISSED SPEECH =     81.75 secs (  1.1 percent of scored time)
FALARM SPEECH =      0.01 secs (  0.0 percent of scored time)
 MISSED WORDS =      0         (100.0 percent of scored words)
---------------------------------------------
SCORED SPEAKER TIME =   3807.20 secs (100.0 percent of scored speech)
MISSED SPEAKER TIME =     81.75 secs (  2.1 percent of scored speaker time)
FALARM SPEAKER TIME =    248.01 secs (  6.5 percent of scored speaker time)
 SPEAKER ERROR TIME =    193.09 secs (  5.1 percent of scored speaker time)
SPEAKER ERROR WORDS =      0         (100.0 percent of scored speaker words)
---------------------------------------------
 OVERALL SPEAKER DIARIZATION ERROR = 13.73 percent of scored speaker time  `(ALL)
---------------------------------------------
 Speaker type confusion matrix -- speaker weighted
  REF\SYS (count)      unknown               MISS              
unknown                 104 / 100.0%          0 /   0.0%
  FALSE ALARM             0 /   0.0%
---------------------------------------------
 Speaker type confusion matrix -- time weighted
  REF\SYS (seconds)    unknown               MISS              
unknown             3725.45 /  97.9%      81.75 /   2.1%
  FALSE ALARM        248.01 /   6.5%
---------------------------------------------
