In [None]:
import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from collections import OrderedDict

# load other modules --> repo root path
sys.path.insert(0, "../")

import torch
from utils import text, audio
from utils import build_model
from params.params import Params as hp
from modules.tacotron2 import Tacotron

In [None]:
def encode(model, inputs):
    """Run the model's text embedding and encoder over raw input lines.

    Each line is split on ``"|"``; field 0 is the text and field 2 the
    language name (only read when ``hp.multi_language`` is set). The middle
    field is not read here. Blank / whitespace-only lines are skipped.

    Args:
        model: model exposing ``_embedding`` and ``_encoder`` (as returned by
            ``build_model``).
        inputs: iterable of ``"|"``-separated input lines.

    Returns:
        A list with one ``(text, char_ids, encoded)`` tuple per non-blank line:
        ``text`` is the raw text field; ``char_ids`` gives, for every character
        of the text, an id assigned in order of first appearance (stable across
        runs); ``encoded`` is the encoder output as a NumPy array with the batch
        dimension squeezed out and the last time step dropped.
    """
    lines = [line.rstrip().split('|') for line in inputs if line.strip()]
    use_cuda = torch.cuda.is_available()
    encodeds = []

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        for fields in lines:
            text_field = fields[0]
            t = torch.LongTensor(text.to_sequence(text_field, use_phonemes=hp.use_phonemes))
            l = torch.LongTensor([hp.languages.index(fields[2])]) if hp.multi_language else None

            if use_cuda:
                t = t.cuda(non_blocking=True)
                # Compare with None: a tensor holding language index 0 is falsy.
                if l is not None:
                    l = l.cuda(non_blocking=True)

            t = t.unsqueeze(0)  # add the batch dimension
            embedded = model._embedding(t)
            # Lengths are kept on the CPU, as in the original code.
            encoded = model._encoder(embedded, torch.LongTensor([t.size(1)]), l)

            # Ids in order of first appearance; set() ordering changes between runs.
            char_to_id = {}
            for ch in text_field:
                char_to_id.setdefault(ch, len(char_to_id))
            char_ids = [char_to_id[ch] for ch in text_field]

            # NOTE(review): the final encoder step is dropped -- presumably an
            # end-of-sequence token appended by text.to_sequence; confirm.
            encodeds.append((text_field, char_ids, encoded.squeeze(0).cpu().detach().numpy()[:-1, :]))

    return encodeds

In [None]:
# Path to the trained checkpoint to inspect. Fill it in here, or supply it
# through the CHECKPOINT environment variable (empty string if neither is set).
checkpoint = os.environ.get("CHECKPOINT", "")

In [None]:
# Peek at the 'parameters' entry stored in the checkpoint before building the model.
# Fail early with a clear message if the path above was not filled in.
if not os.path.isfile(checkpoint):
    raise FileNotFoundError(f"checkpoint file not found: {checkpoint!r} (set `checkpoint` in the cell above)")
# NOTE: torch.load unpickles the file -- only open checkpoints from trusted sources.
torch.load(checkpoint, map_location="cpu")['parameters']

In [None]:
# Rebuild the network from the checkpoint and switch it to inference mode.
# NOTE(review): `hp` is read only after build_model, presumably because
# build_model loads the checkpoint's hyper-parameters into it -- confirm.
model = build_model(checkpoint)
model.eval()
print(hp.encoder_type)

# Encoder output: t-SNE of per-character encodings

In [None]:
# Lines in the "|"-separated layout read by `encode`: text, an empty middle
# field (not read by `encode`), then the language tag. Each sentence is
# paired with both the "fr" and "de" tags.
german_text = "erlauben sie bitte, dass ich mich kurz vorstelle. ich heiße jana novakova."
french_text = "les socialistes et les républicains sont venus apporter leurs voix à la majorité pour ce texte."
inputs = [
    f"{sentence}||{language}"
    for sentence in (german_text, french_text)
    for language in ("fr", "de")
]

In [None]:
# Encode every input line except the last two (i.e. only the first sentence,
# tagged once as "fr" and once as "de").
selected_inputs = inputs[:-2]
embeddings = encode(model, selected_inputs)

In [None]:
# 2-D t-SNE projection of each input's per-step encoder outputs.
# A fixed random_state makes the layout reproducible across runs; perplexity
# must stay below the number of points, which short inputs would otherwise
# violate with the default of 30.
TSNE_SEED = 0
tsne = [
    (t, c, TSNE(n_components=2, perplexity=min(30.0, len(e) - 1), random_state=TSNE_SEED).fit_transform(e))
    for (t, c, e) in embeddings
]

In [None]:
# One panel per encoded input: each character is drawn at its 2-D t-SNE
# position, using the character itself as the marker (spaces shown as "/").
n_cols = 3
n_rows = max(1, (len(tsne) + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(19, 25 * n_rows / 4), squeeze=False)
for index, (ax, (t, c, e)) in enumerate(zip(axes.flat, tsne)):
    # zip stops at the shorter sequence in case the encoder produced a different
    # number of steps than there are characters (e.g. with phonemes) -- TODO check.
    for ch, (x, y) in zip(t, e):
        ax.scatter(x, y, c='k', marker=r"$ {} $".format(ch.replace(" ", "/")), alpha=0.7, s=50)
    ax.set_title(f"input {index}: {t[:40]}{'…' if len(t) > 40 else ''}", fontsize=10)
# Hide the grid cells that did not receive an input.
for ax in axes.flat[len(tsne):]:
    ax.axis("off")
plt.tight_layout()
plt.show()

# Language embeddings: pairwise cosine similarity

In [None]:
def show_similarity(similarity, languages):
    """Plot pairwise similarities between languages as a lower-triangular heat map.

    Each unordered pair of distinct languages is shown once. Rows are
    ``languages[1:]`` and columns are ``languages[:-1]``, so self-similarities
    are left out. Values are min-max rescaled over the plotted pairs to [0, 1]
    so small differences stay visible. Cells above the diagonal are masked
    (left blank) rather than drawn in the colour of the minimum.

    Args:
        similarity: square (n, n) similarity matrix whose row/column order
            matches ``languages``.
        languages: the n language names.

    Raises:
        ValueError: if fewer than two languages are given, or if the shape of
            ``similarity`` does not match ``len(languages)``.
    """
    languages = list(languages)
    if len(languages) < 2:
        raise ValueError("show_similarity needs at least two languages")
    sim = np.asarray(similarity, dtype=float)
    if sim.shape != (len(languages), len(languages)):
        raise ValueError(f"similarity has shape {sim.shape}, expected ({len(languages)}, {len(languages)})")

    # Drop the first row and last column so the main diagonal (self-similarity)
    # falls outside the plotted matrix.
    pairs = sim[1:, :-1]
    n = pairs.shape[0]
    lower = np.tri(n, dtype=bool)  # True on and below the diagonal

    lo, hi = pairs[lower].min(), pairs[lower].max()
    span = hi - lo
    # Guard the degenerate case (e.g. only two languages) where all plotted values are equal.
    scaled = (pairs - lo) / span if span > 0 else np.zeros_like(pairs)
    shown = np.ma.masked_array(scaled, mask=~lower)

    fig, ax = plt.subplots(figsize=(6, 6))
    image = ax.matshow(shown, interpolation='nearest')
    fig.colorbar(image)
    # Explicit tick positions: setting only the labels relies on the default
    # locator and mislabels cells once it stops placing a tick on every cell.
    ax.set_xticks(range(n))
    ax.set_xticklabels(languages[:-1], rotation='vertical')
    ax.set_yticks(range(n))
    ax.set_yticklabels(languages[1:])
    plt.show()

In [None]:
# Weight table of the encoder's `_embedding` module; later cells compare its
# rows against hp.languages (presumably one row per language -- TODO confirm).
# .cpu() is required before .numpy() when the model lives on the GPU.
embeddings = model._encoder._embedding.weight.detach().cpu().numpy()

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [None]:
# Cosine similarity between every pair of embedding rows.
similarity = cosine_similarity(X=embeddings)

In [None]:
# Heat map of the pairwise similarities between the embedding rows.
show_similarity(similarity=similarity, languages=hp.languages)

In [None]:
# Project the embedding rows through the `_bottleneck` weight of one encoder layer.
# NOTE(review): `x @ W.T` assumes the weight is laid out (out_features, in_features)
# and ignores any bias -- confirm against the bottleneck module.
layer_number = 2
layer_weights = model._encoder._layers[layer_number]._convolution._bottleneck.weight.detach().cpu().numpy()
bottleneck_embeddings = embeddings @ layer_weights.T

In [None]:
# Cosine similarity between every pair of projected rows.
bottleneck_similarity = cosine_similarity(X=bottleneck_embeddings)

In [None]:
# Same heat map, after the bottleneck projection.
show_similarity(similarity=bottleneck_similarity, languages=hp.languages)