In [19]:
% matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

import os
import pickle as pkl

from utils import sqrtm, bures, MI_dist, MK_dist, KR_dist, fidelity, MI_fidelity, MK_fidelity

In [7]:
# Root directory holding the elliptical-embedding files. Set the PATH_ELL
# environment variable to point elsewhere instead of editing this local path.
PATH_ELL = os.environ.get("PATH_ELL", "/Users/boris/Documents/These/Donnees/Elliptical_Embeddings")

### Data loading

In [12]:
# allow_pickle=True lets np.load unpickle object data, which can execute
# arbitrary code -- only point this at files you trust.
data = np.load(os.path.join(PATH_ELL, "embeddings_5"), allow_pickle=True)

# Per-word means (2-D) and per-word matrices (3-D stack, as used by the KNN
# helpers). `vars` shadows the builtin but is kept because later cells use it.
means, vars = data["c_means"], data["c_vars"]

In [13]:
with open(os.path.join(PATH_ELL, 'words_to_idxs.pkl'), 'rb') as f:
    # Maps each word to its index in the full vocabulary.
    # encoding='latin1' is the documented way to read pickles presumably written
    # by Python 2 (byte strings / numpy arrays) from Python 3.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    words_to_idxs = pkl.load(f, encoding='latin1')

In [14]:
with open(os.path.join(PATH_ELL, 'vocab_words.pkl'), 'rb') as f:
    # encoding='latin1' reads pickles presumably written by Python 2.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    idxs_to_words_dict = pkl.load(f, encoding='latin1')

# Object array (index -> word) so it can be fancy-indexed with the integer
# arrays returned by np.argsort in the KNN helpers. Sized from `means` rather
# than re-reading `data`, so this cell still runs after `del data`.
idxs_to_words = np.empty(len(means), dtype=object)

for i in range(len(means)):
    idxs_to_words[i] = idxs_to_words_dict[i]

In [18]:
# Restrict the vocabulary to its first n = 30K entries (the most frequent words,
# assuming the vocabulary is frequency-sorted -- confirm).
# Row 0 is skipped (presumably a padding/unknown entry -- confirm), so row j of
# the *_30K arrays is full-vocabulary index j + 1; the KNN helpers subtract 1
# from words_to_idxs to account for this offset.
n = 30000

means_30K = means[1:n+1]
vars_30K = vars[1:n+1]

idxs_to_words_30K = idxs_to_words[1:n+1]

In [16]:
# Release the loaded archive now that means/vars have been extracted.
# If np.load returned an NpzFile, close its open file handle explicitly
# (`del` alone only drops the reference). Guarded so re-running is a no-op.
if "data" in globals():
    if hasattr(data, "close"):
        data.close()
    del data

### KNN functions

In [34]:
def KNN(word, means, vars, k=10, metric='bures', use_means = True,
        renorm = True, words_to_idxs = words_to_idxs,
        idxs_to_words = idxs_to_words_30K):
    """Return the `k` nearest neighbours of `word` among the rows of `means`/`vars`.

    The distance from `word` to row i is ``bures(sigma, vars[i])`` (from ``utils``)
    plus, when `use_means` is true, the squared Euclidean distance between the
    mean vectors. `word` itself is not excluded from the candidates.

    Parameters
    ----------
    word : str
        Query word; must be a key of `words_to_idxs`.
    means : numpy.ndarray
        2-D array, one mean vector per row.
    vars : numpy.ndarray
        3-D stack of square matrices, one per row of `means`.
    k : int
        Number of neighbours to return.
    metric : str
        Only 'bures' is supported.
    use_means : bool
        Add the squared distance between means to the matrix distance.
    renorm : bool
        Scale each mean to unit L2 norm and divide each matrix by the square
        root of its trace before comparing.
    words_to_idxs : dict
        Word -> full-vocabulary index. Rows of `means` are assumed to start at
        full index 1 (as with the `*_30K` slices), hence the ``- 1`` below.
    idxs_to_words : numpy.ndarray
        Row index -> word, aligned with `means`.

    Defaults for `words_to_idxs` / `idxs_to_words` are captured when this cell
    is executed.

    Returns
    -------
    numpy.ndarray
        The `k` words with the smallest distances, closest first.

    Raises
    ------
    KeyError
        If `word` is not in `words_to_idxs` or falls outside the rows of `means`.
    ValueError
        If `metric` is not 'bures'.
    """
    if metric != 'bures':
        raise ValueError("unsupported metric {!r}; expected 'bures'".format(metric))

    n = len(means)
    widx = words_to_idxs[word] - 1
    # Full index 0 would become -1 and silently select the last row; indices
    # beyond the restricted vocabulary would overflow. Reject both explicitly.
    if not 0 <= widx < n:
        raise KeyError("{!r} is outside the restricted vocabulary of {} words".format(word, n))

    if renorm:
        vars_renorm = vars / np.sqrt(np.trace(vars, axis1=1, axis2=2))[:, None, None]
        means_renorm = means / np.sqrt((means ** 2).sum(axis=1))[:, None]
    else:
        vars_renorm = vars
        means_renorm = means

    mu = means_renorm[widx]
    sigma = vars_renorm[widx]

    # Squared Euclidean distance to every mean at once; multiplying by the
    # boolean switches the term off when use_means is False.
    mean_term = ((means_renorm - mu) ** 2).sum(axis=1) * use_means
    cov_term = np.array([bures(sigma, vars_renorm[i]) for i in range(n)])
    dists = mean_term + cov_term

    idxs = np.argsort(dists)[:k]
    return idxs_to_words[idxs]
        
def mediated_KNN(word, reference, means, vars, k=20, d = 4, metric='MK', use_means = False,
                 renorm = True, words_to_idxs = words_to_idxs,
                 idxs_to_words = idxs_to_words_30K):
    """Return the `k` nearest neighbours of `word` as seen through `reference`.

    Every matrix in `vars` is rotated into the eigenbasis of the reference
    word's matrix (eigenvectors ordered by decreasing eigenvalue). The distance
    to row i is ``MI_dist`` or ``MK_dist`` (from ``utils``) between the rotated
    matrices, called with `d` -- presumably restricting the comparison to the
    `d` leading directions of the reference; confirm in utils. When
    `use_means` is true, the squared Euclidean distance between means is added.

    Parameters
    ----------
    word : str
        Query word; must be a key of `words_to_idxs`.
    reference : str
        Word whose eigenbasis mediates the comparison.
    means : numpy.ndarray
        2-D array, one mean vector per row.
    vars : numpy.ndarray
        3-D stack of square matrices, one per row of `means`.
    k : int
        Number of neighbours to return.
    d : int
        Passed through to the distance function.
    metric : str
        'MI' (uses ``MI_dist``) or 'MK' (uses ``MK_dist``).
    use_means : bool
        Add the squared distance between means.
    renorm : bool
        Scale each mean to unit L2 norm and divide each matrix by the square
        root of its trace before comparing.
    words_to_idxs : dict
        Word -> full-vocabulary index. Rows of `means` are assumed to start at
        full index 1 (as with the `*_30K` slices), hence the ``- 1`` below.
    idxs_to_words : numpy.ndarray
        Row index -> word, aligned with `means`.

    Returns
    -------
    numpy.ndarray
        The `k` words with the smallest distances, closest first. `word`
        itself is not excluded.

    Raises
    ------
    KeyError
        If `word` or `reference` is not in `words_to_idxs` or falls outside the
        rows of `means`.
    ValueError
        If `metric` is not 'MI' or 'MK'.
    """
    dist_fns = {'MI': MI_dist, 'MK': MK_dist}
    if metric not in dist_fns:
        raise ValueError("unsupported metric {!r}; expected 'MI' or 'MK'".format(metric))
    dist_fn = dist_fns[metric]

    n = len(means)
    widx = words_to_idxs[word] - 1
    refidx = words_to_idxs[reference] - 1
    # Full index 0 would become -1 and silently select the last row; indices
    # beyond the restricted vocabulary would overflow. Reject both explicitly.
    for w, idx in ((word, widx), (reference, refidx)):
        if not 0 <= idx < n:
            raise KeyError("{!r} is outside the restricted vocabulary of {} words".format(w, n))

    if renorm:
        vars_renorm = vars / np.sqrt(np.trace(vars, axis1=1, axis2=2))[:, None, None]
        means_renorm = means / np.sqrt((means ** 2).sum(axis=1))[:, None]
    else:
        vars_renorm = vars
        means_renorm = means

    mu = means_renorm[widx]
    sigma_ref = vars_renorm[refidx]

    # np.linalg.eigh returns eigenvalues in ascending order; reversing the
    # columns puts the reference's largest-eigenvalue directions first.
    _, vecs = np.linalg.eigh(sigma_ref)
    M = vecs[:, ::-1]

    # M^T @ Sigma_i @ M for every i at once (matmul broadcasts over the stack).
    proj_vars = np.matmul(np.matmul(M.T, vars_renorm), M)
    proj_sigma = proj_vars[widx]

    # Squared Euclidean distance to every mean; the boolean switches it off.
    mean_term = ((means_renorm - mu) ** 2).sum(axis=1) * use_means
    cov_term = np.array([dist_fn(proj_sigma, proj_vars[i], d) for i in range(n)])
    dists = mean_term + cov_term

    idxs = np.argsort(dists)[:k]
    return idxs_to_words[idxs]

def sym_diff(word, ref1, ref2, means, vars, k=20, d=4, metric='MK', use_means=False,
             renorm=True, words_to_idxs=words_to_idxs,
             idxs_to_words=idxs_to_words_30K):
    """Contrast `word`'s mediated neighbourhoods under two reference words.

    Runs ``mediated_KNN`` once with `ref1` and once with `ref2` (all other
    arguments shared) and returns a pair of lists:
    (neighbours found only via `ref1`, neighbours found only via `ref2`).
    Each list keeps the closest-first order of its ``mediated_KNN`` result.
    """
    shared = dict(k=k, d=d, metric=metric, use_means=use_means, renorm=renorm,
                  words_to_idxs=words_to_idxs, idxs_to_words=idxs_to_words)
    neighbours_1 = mediated_KNN(word, ref1, means, vars, **shared)
    neighbours_2 = mediated_KNN(word, ref2, means, vars, **shared)

    def only_in(first, second):
        # Order-preserving difference: items of `first` absent from `second`.
        excluded = set(second)
        return [w for w in first if w not in excluded]

    return only_in(neighbours_1, neighbours_2), only_in(neighbours_2, neighbours_1)
    

### Experiments

In [35]:
# Neighbours of 'instrument' measured in the eigenbasis of 'monitor'.
mediated_knn = mediated_KNN(
    'instrument', 'monitor', means_30K, vars_30K,
    k=20, d=4, metric='MK',
)
print(mediated_knn)

['instrument' 'instruments' 'chromatic' 'harmonics' 'cathode' 'tuning'
 'monitor' 'tonal' 'sampler' 'rca' 'watts' 'amps' 'instrumentation'
 'synthesizers' 'synthesizer' 'telescope' 'resonant' 'harpsichord'
 'tungsten' 'ambient']


In [36]:
s1, s2 = sym_diff('instrument', 'monitor', 'oboe', means_30K, vars_30K)

# Mediated neighbours of 'instrument' found via one reference but not the other.
for section_title, section_words in (("Ref #1 - Ref #2:\n", s1),
                                     ("\nRef #2 - Ref #1:\n", s2)):
    print(section_title)
    for w in section_words:
        print(w)

Ref #1 - Ref #2:

cathode
monitor
sampler
rca
watts
instrumentation
synthesizers
synthesizer
telescope
ambient

Ref #2 - Ref #1:

tuned
trombone
guitar
harmonic
octave
baritone
clarinet
compositional
saxophone
virtuoso


In [37]:
s1, s2 = sym_diff('windows', 'pc', 'door', means_30K, vars_30K)

# Mediated neighbours of 'windows' found via one reference but not the other.
for section_title, section_words in (("Ref #1 - Ref #2:\n", s1),
                                     ("\nRef #2 - Ref #1:\n", s2)):
    print(section_title)
    for w in section_words:
        print(w)

Ref #1 - Ref #2:

netscape
installer
doubleclick
burner
installs
adapter
router
cpus

Ref #2 - Ref #1:

screwed
recessed
rails
ceilings
tiling
upvc
profiled
roofs


In [38]:
s1, s2 = sym_diff('fox', 'media', 'hedgehog', means_30K, vars_30K)

# Mediated neighbours of 'fox' found via one reference but not the other.
for section_title, section_words in (("Ref #1 - Ref #2:\n", s1),
                                     ("\nRef #2 - Ref #1:\n", s2)):
    print(section_title)
    for w in section_words:
        print(w)

Ref #1 - Ref #2:

penny
quiz
whitman
outraged
tinker
ads
keating
palin
show

Ref #2 - Ref #1:

panther
reintroduced
kangaroo
harriet
fair
hedgehog
bush
paw
bunny
