In [65]:
import glob
import librosa
import numpy as np
import os
import torch

from hparam import hparam as hp
from speech_embedder_net import SpeechEmbedder
from VAD_segments import VAD_chunk
from dvector_create import concat_segs, get_STFTs, align_embeddings
import tqdm

from scipy.spatial import distance
from collections import Counter

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Build the speaker-embedding network and restore its weights from the
# checkpoint path configured in `hp.model.model_path`.
embedder_net = SpeechEmbedder()
# `map_location` remaps the stored tensors onto `device` while loading, so a
# checkpoint saved from a GPU can still be restored when `device` is the CPU.
embedder_net.load_state_dict(torch.load(hp.model.model_path, map_location=device))
# Switch to evaluation mode: the network is only used for inference below.
embedder_net.eval()
# Last expression of the cell: moves the module to `device` and displays its repr.
embedder_net.to(device)

SpeechEmbedder(
  (LSTM_stack): LSTM(40, 768, num_layers=3, batch_first=True)
  (projection): Linear(in_features=768, out_features=256, bias=True)
)

In [None]:
# Per-file results, keyed by the wav file's base name:
#   wav_to_seq     -> aligned embeddings produced for that file
#   wav_to_speaker -> name of the speaker directory the file was found in
# NOTE(review): keys are base names only, so two files with the same name under
# different speaker directories would overwrite each other — confirm names are unique.
wav_to_seq = {}
wav_to_speaker = {}

# sorted() makes the processing order reproducible; glob order is arbitrary.
for file in tqdm.tqdm_notebook(sorted(glob.glob("./data/*/16/*wav"))):
    # First argument is presumably the VAD aggressiveness — confirm in VAD_segments.
    times, segs = VAD_chunk(2, file)

    if not len(segs):
        print('No voice activity detected')
        continue

    concat_seg = concat_segs(times, segs)
    STFT_frames = get_STFTs(concat_seg)
    # Stack the per-window frames along a new last axis, then reverse the axis
    # order so that the stacked axis comes first before feeding the network.
    STFT_frames = np.stack(STFT_frames, axis=2)
    STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2,1,0)))
    # Inference only: skip autograd bookkeeping so no graph is kept alive
    # for every processed file.
    with torch.no_grad():
        embeddings = embedder_net(STFT_frames.to(device))
    aligned_embeddings = align_embeddings(embeddings.cpu().numpy())

    file_name = os.path.basename(file)
    # Files match ./data/<speaker>/16/<name>.wav, so the speaker is the
    # grandparent directory; os.path keeps this independent of the separator.
    speaker_name = os.path.basename(os.path.dirname(os.path.dirname(file)))
    wav_to_seq[file_name] = aligned_embeddings
    wav_to_speaker[file_name] = speaker_name

HBox(children=(IntProgress(value=0, max=2248), HTML(value='')))

### Load etalons (reference embeddings for the known speakers)

In [49]:
# Reference ("etalon") embeddings: one previously processed recording per
# known speaker, looked up by file name in `wav_to_seq`.
etalons = dict([
    ("ruslan", wav_to_seq["007458_RUSLAN.wav"]),
    ("navalny", wav_to_seq["4a23b379-abd9-4e6a-a58f-6fac50475292.wav"]),
    ("urgant", wav_to_seq["5191d6cf-46c3-4815-a27e-816b672c525b.wav"]),
])

In [86]:
def get_speakers(sequence, max_dist=0.4, min_speaker_ratio=0.6, references=None, verbose=True):
    """Decide which known speaker, if any, dominates a sequence of embeddings.

    Every embedding in ``sequence`` is labelled with the reference speaker whose
    reference embeddings have the smallest mean cosine distance to it, provided
    that distance does not exceed ``max_dist``; otherwise it is labelled
    ``"unknown"``. The most frequent label is returned when it is a real
    speaker and accounts for at least ``min_speaker_ratio`` of the embeddings.

    Args:
        sequence: iterable of 1-D embedding vectors to classify.
        max_dist: largest mean cosine distance at which an embedding is still
            attributed to a reference speaker.
        min_speaker_ratio: minimum fraction of embeddings the winning speaker
            must cover for the whole sequence to be attributed to them.
        references: mapping of speaker name to an iterable of that speaker's
            reference embedding vectors. Defaults to the notebook-level
            ``etalons`` dict.
        verbose: when true (the default), print the per-embedding labels.

    Returns:
        The winning speaker name, or ``None`` when the sequence is empty, the
        most frequent label is ``"unknown"``, or no speaker reaches
        ``min_speaker_ratio``.
    """
    if references is None:
        references = etalons

    results = []

    for seq in sequence:
        min_dist, min_etalon = None, None
        for etalon, etalon_seqs in references.items():
            etalon_dist = np.mean([distance.cosine(seq, etalon_seq) for etalon_seq in etalon_seqs])
            # Keep the closest speaker among those within the distance threshold.
            if etalon_dist <= max_dist and (min_dist is None or etalon_dist < min_dist):
                min_dist = etalon_dist
                min_etalon = etalon

        results.append(min_etalon if min_etalon is not None else "unknown")

    if verbose:
        print(results)

    # Nothing to vote on: avoid indexing into an empty most_common() result.
    if not results:
        return None

    speaker, count = Counter(results).most_common(1)[0]
    if speaker != "unknown" and (count / len(results)) >= min_speaker_ratio:
        return speaker
    return None

In [87]:
# Smoke test: classify a RUSLAN recording other than the one used as the
# "ruslan" etalon; the printed list shows the label given to each embedding.
get_speakers(wav_to_seq["011158_RUSLAN.wav"])

['ruslan', 'ruslan', 'ruslan', 'ruslan', 'ruslan', 'ruslan', 'unknown', 'ruslan', 'ruslan', 'ruslan', 'unknown', 'ruslan', 'ruslan', 'ruslan']


'ruslan'