In [1]:
import numpy as np
import pandas as pd
import torch

from collections import defaultdict

from src.dataset import Bigrammer
from src.dataset_reader import sentence2words_preprocessing

In [2]:
# The checkpoint is a pickled pair: presumably a src.dataset.Bigrammer plus the
# embedding module (its repr below is `Embedding(102651, 64)`). torch.load unpickles
# it, so only load checkpoint files you trust.
# map_location="cpu" lets a checkpoint saved from a GPU session load on a CPU-only
# machine; without it torch.load tries to restore the tensors onto their saved device.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects custom pickled
# objects; pass weights_only=False there if this load fails.
bigrammer, w2v = torch.load("models/embedding.pth", map_location="cpu")
w2v.to('cpu')

Embedding(102651, 64)

In [3]:
# Sample vocabulary word used to sanity-check the embedding lookup.
word = "cat"

In [4]:
v_size = len(bigrammer.word2idx)  # number of entries in the word -> index map (not used by the cells shown)
# Per-dimension scale for get_norm_vec: the largest |weight| in each embedding column,
# taken over the whole table. Detached so the scale is a plain constant rather than a
# tensor that keeps tracking gradients back to w2v.weight.
m, _ = torch.max(torch.abs(w2v.weight.detach()), dim=0)

In [5]:
def get_norm_vec(word):
    """Return the embedding of ``word`` divided element-wise by the scale ``m``.

    ``m`` holds, for each embedding dimension, the largest absolute value found in
    that column of the embedding table, so every component of the result lies in
    [-1, 1]. This is a per-dimension rescaling, not L2 normalisation. It changes the
    relative weight of the dimensions, and therefore also the cosine similarities
    computed from the returned vectors.

    Args:
        word: token to look up in ``bigrammer.word2idx``. Presumably unknown words
            raise ``KeyError``; callers check membership first.

    Returns:
        A ``(1, embedding_dim)`` tensor, computed without autograd tracking.
    """
    idx = torch.tensor([bigrammer.word2idx[word]], dtype=torch.long)
    # Evaluation only: avoid building an autograd graph for every lookup.
    with torch.no_grad():
        return w2v(idx) / m

# Sanity check: shape and values of the scaled vector for the sample word.
word_vec = get_norm_vec(word=word)
print(word_vec.shape, word_vec)

torch.Size([1, 64]) tensor([[-0.0785,  0.0371,  0.0719, -0.1542,  0.0559,  0.1083, -0.1175, -0.2281,
         -0.1631,  0.0700,  0.0250, -0.0196,  0.1202,  0.1133, -0.0451,  0.2806,
         -0.0919,  0.1375, -0.0587,  0.2352, -0.0828, -0.0337, -0.0672,  0.0778,
         -0.0709, -0.0403,  0.0649,  0.0012, -0.1063, -0.0342, -0.1519,  0.1177,
          0.2297,  0.0981,  0.1296, -0.1640,  0.0391, -0.0658, -0.1415, -0.0076,
          0.0783,  0.0906,  0.1301,  0.1382, -0.2200, -0.0929,  0.0465, -0.0711,
         -0.0004, -0.1556,  0.0417,  0.1611, -0.0250, -0.0372, -0.0395, -0.1294,
          0.0758, -0.0476,  0.0395, -0.0509,  0.0225,  0.0681, -0.1601,  0.0360]],
       grad_fn=<DivBackward0>)


In [6]:
# Analogy test set: one row per analogy with columns type, word1, word2, word3, target;
# the file's first column becomes the index.
test_set = pd.read_csv(filepath_or_buffer="data/google-analogies.csv", index_col=0)
test_set.head()

Unnamed: 0_level_0,type,word1,word2,word3,target
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,capital-common-countries,Athens,Greece,Baghdad,Iraq
1,capital-common-countries,Athens,Greece,Bangkok,Thailand
2,capital-common-countries,Athens,Greece,Beijing,China
3,capital-common-countries,Athens,Greece,Berlin,Germany
4,capital-common-countries,Athens,Greece,Bern,Switzerland


In [7]:
columns = test_set.columns
# Plain array for positional row access: each row is [type, word1, word2, word3, target].
vals = test_set.to_numpy()
vals.shape

(19544, 5)

In [8]:
# Keep only analogies whose four (lower-cased) words all have an embedding,
# grouped by category, and count how many rows are kept vs. dropped.
pr, not_pr = 0, 0
clean_set = defaultdict(list)  # category -> list of [word1, word2, word3, target]
for cat, *words in vals:
    words = [w.lower() for w in words]
    # Generator (not a list) so all() stops at the first out-of-vocabulary word.
    if all(w in bigrammer.word2idx for w in words):
        pr += 1
        clean_set[cat].append(words)
    else:
        not_pr += 1

print("Test cases: {} present / {} not present".format(pr, not_pr))

Test cases: 16600 present / 2944 not present


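# 3CosAdd by categories

For every analogy `word1 : word2 :: word3 : target`, the 3CosAdd prediction is the vector `word3 + (word2 - word1)`. Each case is scored by comparing that prediction with the target's vector using cosine similarity, and the scores are summarised per category.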

In [9]:
def cos(a, b):
    """Cosine similarity of two tensors, each viewed as one flat vector.

    Returns a 0-d tensor. Zero-norm inputs give nan (0 / 0).
    """
    u = a.reshape(-1)
    v = b.reshape(-1)
    numerator = torch.dot(u, v)
    return numerator / (u.norm() * v.norm())


# Tiny worked example: cos((1, 0), (1, 0.5)) = 1 / sqrt(1.25) ~ 0.894.
a = torch.tensor([[1.0, 0.0]], dtype=torch.float32)
b = torch.tensor([[1.0, 0.5]], dtype=torch.float32)

cos(a, b)

tensor(0.8944)

In [10]:
# category -> list of cosine similarities between the 3CosAdd prediction and the target.
# These are similarities (higher is better, range [-1, 1]), not distances. The dict
# keeps its historical name because the summary cells read it.
distances = {}
with torch.no_grad():  # pure evaluation: no autograd graph needed
    for cat, samples in clean_set.items():
        sims = []
        for case in samples:
            # case is [word1, word2, word3, target], lower-cased when the set was cleaned
            w1, w2, w3, target = (get_norm_vec(w) for w in case)
            # 3CosAdd prediction, e.g. Baghdad + (Greece - Athens) should land near Iraq
            destination = w3 + (w2 - w1)
            # Keep the sign: a negative cosine means the prediction points away from
            # the target, and abs() would wrongly count that as a close match.
            sims.append(cos(target, destination).item())
        distances[cat] = sims

In [11]:
# Accumulators filled by the per-category summary loop.
all_dists = []        # every per-case score, pooled across categories
cats_mean_dists = []  # one (max, mean, std) triple per category


def print_stat(metrics):
    """Print the (max, mean, std) triple ``metrics``, one tab-indented line each.

    Values are indexed positionally and formatted with three decimals.
    """
    for position, label in enumerate(("max", "mean", "std")):
        print("\t {}:\t{:.3f}".format(label, metrics[position]))
    
# Per-category summary: print max/mean/std of the scores, pool every score into
# all_dists, and keep each category's triple for the macro average.
for cat, scores in distances.items():
    print(cat)
    stats = (np.max(scores), np.mean(scores), np.std(scores))
    print_stat(stats)
    all_dists.extend(scores)
    cats_mean_dists.append(stats)

capital-common-countries
	 max:	0.737
	 mean:	0.363
	 std:	0.172
capital-world
	 max:	0.737
	 mean:	0.222
	 std:	0.138
currency
	 max:	0.462
	 mean:	0.119
	 std:	0.093
city-in-state
	 max:	0.756
	 mean:	0.332
	 std:	0.144
family
	 max:	0.927
	 mean:	0.409
	 std:	0.221
gram1-adjective-to-adverb
	 max:	0.689
	 mean:	0.266
	 std:	0.174
gram2-opposite
	 max:	0.602
	 mean:	0.178
	 std:	0.121
gram3-comparative
	 max:	0.792
	 mean:	0.349
	 std:	0.191
gram4-superlative
	 max:	0.755
	 mean:	0.259
	 std:	0.181
gram5-present-participle
	 max:	0.780
	 mean:	0.322
	 std:	0.172
gram6-nationality-adjective
	 max:	0.761
	 mean:	0.345
	 std:	0.190
gram7-past-tense
	 max:	0.771
	 mean:	0.346
	 std:	0.185
gram8-plural
	 max:	0.803
	 mean:	0.294
	 std:	0.183
gram9-plural-verbs
	 max:	0.752
	 mean:	0.301
	 std:	0.179


In [12]:
# Macro average over categories: each printed line is the mean, across categories,
# of that per-category statistic. "max" here is the average of the per-category
# maxima, not the overall maximum.
print("Mean metrics between categories:")
per_category = np.asarray(cats_mean_dists)  # one row of (max, mean, std) per category
metrics = per_category.mean(axis=0)
print_stat(metrics)

Mean metrics between categories:
	 max:	0.737
	 mean:	0.293
	 std:	0.167


In [13]:
# Micro average: statistics over every test case pooled together, so larger
# categories weigh more than in the per-category average.
print("For all test set:")
cs = all_dists
pooled = np.asarray(cs)
print_stat((pooled.max(), pooled.mean(), pooled.std()))

For all test set:
	 max:	0.927
	 mean:	0.296
	 std:	0.179
