## Imports

In [208]:
import warnings
warnings.simplefilter('ignore')

import matplotlib.pyplot as plt
import json
import os
import re
import pandas as pd
import numpy as np
import torch
from matplotlib import cm
from scipy.io import wavfile
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

import _datasets as D
import _models as M

# small positive constant available for guarding divisions
EPS = 1e-8

# run on the first CUDA device when one is available, otherwise on the CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def reverse_non_unique_mapping(d):
    """Invert a many-to-one mapping.

    Each key of `d` is converted with `str` and each value with `int`.
    The result maps every distinct integer value to the list of string
    keys that pointed to it, in the iteration order of `d`.

    Args:
        d: dict whose values can be converted with `int`.

    Returns:
        dict mapping `int` values to lists of `str` keys.
    """
    inverse = {}
    for key, value in d.items():
        name = str(key)
        group = int(value)
        # create the bucket on first sight of `group`, then extend it
        inverse.setdefault(group, []).append(name)
    return inverse

## Generate speaker-mean vectors per checkpoint

In [104]:
checkpoint_path = 'weights/sv'

# Embed every utterance of every training speaker with the network restored
# from `checkpoint_path`, average the embeddings per speaker, and save the
# resulting {speaker_id: mean_vector} dict next to the checkpoint.
if os.path.exists(checkpoint_path):
    with torch.no_grad():

        # the hidden size is read from the JSON file stored next to the checkpoint
        with open(checkpoint_path + '_params.json', 'r') as fp:
            checkpoint_params = json.load(fp)
        hidden_size = int(checkpoint_params['hidden_size'])

        # instantiate network
        network = M.NetSV(hidden_size=hidden_size, num_layers=2).to(device)
        sd = torch.load(checkpoint_path, map_location=torch.device(device))
        network.load_state_dict(sd)
        # inference only: no_grad() does not disable dropout / batch-norm
        # updates, eval() does
        network.eval()

        # loop through each speaker
        mean_vectors = {}
        for speaker_id in D.speaker_ids_tr:
            df = D.librispeech.query(f'speaker_id == "{speaker_id}"')
            vectors = []
            for filepath in df.filepath:
                (_, s) = wavfile.read(filepath)
                # cast before abs(): on integer PCM the most negative sample
                # wraps around under np.abs and would corrupt the peak value
                s = s.astype(np.float64)
                # peak-normalize; EPS keeps an all-zero file from dividing by 0
                s = s / max(np.abs(s).max(), EPS)
                s = torch.Tensor(s).unsqueeze(0).to(device)
                features = network.embedding(s).squeeze()
                vectors.append(features)
            if not vectors:
                # torch.stack raises on an empty list; skip speakers without files
                print(f'No utterances found for speaker {speaker_id}; skipped.')
                continue
            vectors = torch.stack(vectors)
            mean_vectors[speaker_id] = torch.mean(vectors, dim=0).cpu().numpy()

        # mean_vectors is a dict, so np.save stores it as a pickled object;
        # reading it back requires np.load(..., allow_pickle=True).item()
        np.save(checkpoint_path + '_speakers.npy', mean_vectors)
        print('Finished!')
else:
    # make the skip visible: later cells read `mean_vectors`
    print(f'Checkpoint not found at {checkpoint_path!r}; '
          'speaker vectors were not generated.')

Finished!


## Cluster the speaker vectors with k-means and visualize the clusters with t-SNE

In [262]:
import yaml

# --- configuration ---
figsize = (4, 2)
dpi = 300
seed = 0
plot = False                 # set True to render and save one t-SNE figure per k
save_mapping_to_file = True  # write the speaker <-> cluster mappings as YAML

plt.style.use(['science', 'ieee'])

# `mean_vectors` is the {speaker_id: mean embedding} dict built by the
# previous cell; stack it into a (speakers x features) matrix plus labels
df = mean_vectors
data = np.vstack(list(df.values()))
labels = np.vstack(list(df.keys())).squeeze()

if plot:
    # Everything in this branch is independent of the number of clusters,
    # so it is computed once instead of once per k (or once per point).
    colors = plt.get_cmap('tab10').colors
    tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=300, random_state=seed)
    data_2d = tsne.fit_transform(data)
    # marker per speaker: 'P' for gender 'M', 'D' for gender 'F'
    markers = [
        {'M': 'P', 'F': 'D'}[
            D.speakers_tr[D.speakers_tr.speaker_id==label].gender.item()]
        for label in labels
    ]
    # savefig does not create missing directories
    os.makedirs('figures', exist_ok=True)

for num_clusters in range(2, 11):

    kmeans = KMeans(n_clusters=num_clusters, random_state=seed).fit(data)
    classes = kmeans.labels_
    # speaker -> cluster index, and cluster index -> list of speakers
    mapping_1 = {str(k): int(v) for (k,v) in zip(labels, classes)}
    mapping_2 = reverse_non_unique_mapping(mapping_1)

    if save_mapping_to_file:
        with open(checkpoint_path+f'_mapping_k={num_clusters:02d}.yaml', 'w') as fp:
            yaml.dump({'speakers': mapping_1, 'specialists': mapping_2}, fp)

    if not plot:
        continue

    fig = plt.figure(figsize=figsize, dpi=dpi)
    plt.xlabel(r'1\textsuperscript{st} Dimension')
    plt.ylabel(r'2\textsuperscript{nd} Dimension')

    # plot by k-means: color encodes the cluster, marker shape the gender
    for i in range(len(labels)):
        plt.scatter(data_2d[i, 0], data_2d[i, 1], color=colors[classes[i]], s=32,
                    linewidths=0.2, edgecolors='k', marker=markers[i])

    # proxy artists placed at (-100, -100), outside the axis limits set
    # below, so they only show up as legend entries
    hd_gender = [
        plt.plot((-100,), (-100,), ls='none', marker='P', markersize=6, c='k',
                        label='Male')[0],
        plt.plot((-100,), (-100,), ls='none', marker='D', markersize=4.5, c='k',
                        label='Female')[0]
    ]
    hd_cluster = [
        plt.plot((-100,), (-100,), ls='none', marker='o', mec='k',
                 markeredgewidth=0.5,
                 color=colors[c],
                 label=(f'{c+1}'))[0]
        for c in range(len(set(classes)))
    ]
    # add_artist keeps the first legend alive when the second one is created
    plt.gca().add_artist(plt.legend(handles=hd_gender, ncol=2, columnspacing=1,
                                    handletextpad=0.2, fontsize=8, loc='lower right'))
    plt.gca().add_artist(plt.legend(handles=hd_cluster, ncol=3, columnspacing=1,
                                    handletextpad=0, fontsize=8, loc='upper left'))
    plt.xlim([-15, 15])
    plt.ylim([-15, 15])

    plt.savefig(f'figures/fig_tsne-k{num_clusters:02d}.pdf',
                facecolor='white', transparent=False, bbox_inches='tight')
    plt.show()
    # release the figure; otherwise every figure of the loop stays in memory
    plt.close(fig)