In [1]:
!pip install fasttext
!pip install torchtext



In [2]:
from torchtext.vocab import FastText

embedding = FastText('en')
embedding['hello'].numpy()

array([-1.5945e-01, -1.8259e-01,  3.3443e-02,  1.8813e-01, -6.7903e-02,
       -1.3663e-01, -2.5559e-01,  1.1000e-01,  1.7275e-01,  5.1971e-02,
       -2.3302e-02,  3.8866e-02, -2.4515e-01, -2.1588e-01,  3.5925e-01,
       -8.2526e-02,  1.2176e-01, -2.6775e-01,  1.0072e-01, -1.3639e-01,
       -9.2658e-02,  5.1837e-01,  1.7736e-01,  9.4878e-02, -1.8461e-01,
       -4.2829e-02,  1.4114e-02,  1.6811e-01, -1.8565e-01,  3.4976e-02,
       -1.0293e-01,  1.7954e-01, -5.2766e-02,  7.2047e-02, -4.2704e-01,
       -1.1616e-01, -9.4875e-03,  1.4199e-01, -2.2782e-01, -1.7292e-02,
        8.2802e-02, -4.4512e-01, -7.5935e-02, -1.4392e-01, -8.2461e-02,
        2.0123e-01, -9.5344e-02, -1.1042e-01, -4.6817e-01,  2.0362e-01,
       -1.7140e-01, -4.9850e-01,  2.8963e-01, -1.0305e-01,  2.0393e-01,
        5.2971e-01, -2.5396e-01, -5.1891e-01,  2.9941e-01,  1.7933e-01,
        3.0683e-01,  2.5828e-01, -1.8168e-01, -1.0225e-01, -1.1435e-01,
       -1.6304e-01, -1.2424e-01,  3.2814e-01, -2.3099e-01,  1.79

In [3]:
import numpy as np
import pandas as pd
import collections
from tqdm import tqdm

train_texts = list(pd.read_csv('/home/mlepekhin/data/en_train').text.values)
test_texts = list(pd.read_csv('/home/mlepekhin/data/en_test').text.values)

In [4]:
!mkdir en_fasttext_30000
!mkdir en_fasttext_50000
!mkdir en_fasttext_all

mkdir: cannot create directory ‘en_fasttext_30000’: File exists
mkdir: cannot create directory ‘en_fasttext_50000’: File exists
mkdir: cannot create directory ‘en_fasttext_all’: File exists


In [5]:
target_tokens = [token for text in train_texts + test_texts for token in text.lower().split()]
token_count = collections.defaultdict(int)
for token in target_tokens:
    token_count[token] += 1
token_count = sorted(token_count.items(), key=lambda item: (-item[1], item[0]))[:50000]
token_count_dict = dict(token_count)

In [6]:
def is_russian_word(s):
    return all(['a' <= ch <= 'z' for ch in s.lower()])

In [7]:
final_vocab = [pair[0] for pair in token_count\
               if is_russian_word(pair[0]) and np.linalg.norm(embedding[pair[0]].numpy()) > 0]
word2index = {word: index for index, word in enumerate(final_vocab)}
index2word = {index: word for index, word in enumerate(final_vocab)}
embedding_matrix = [embedding[word].numpy() for word in final_vocab]
embedding_matrix = np.array([vec / (np.linalg.norm(vec) if np.linalg.norm(vec) > 0.0001 else 1.0)\
                    for vec in embedding_matrix])

In [8]:
print(len(final_vocab))

38555


In [9]:
import faiss

In [10]:
dim = 300
k = 50
cluster_num = 1000  # количество “командиров”

quantiser = faiss.IndexFlatL2(dim) 
index = faiss.IndexIVFFlat(quantiser, dim, cluster_num)
index.nprobe = 16 

In [11]:
index.train(embedding_matrix)
index.add(embedding_matrix)

In [12]:
D, I = index.search(embedding_matrix, k) 
print(I)
print(D)

[[    0     1    28 ...   278   113    25]
 [    1     0     2 ...  3741   374  3073]
 [    2   145    96 ...   548  7541   578]
 ...
 [38552  8733 23990 ... 10692  9709 27710]
 [38553 38355 31416 ... 15048 23741 22595]
 [38554 37747 34412 ... 26182 22978 10672]]
[[0.         0.4730763  0.672725   ... 1.0767486  1.0774659  1.0783682 ]
 [0.         0.4730763  0.7981688  ... 1.2046138  1.2053785  1.2070303 ]
 [0.         0.5823202  0.698479   ... 1.0188984  1.0192317  1.0207292 ]
 ...
 [0.         0.66974187 0.7857368  ... 0.93739283 0.94024545 0.9444096 ]
 [0.         0.9490695  0.9848088  ... 1.1236608  1.1236701  1.1242998 ]
 [0.         0.87755144 0.9235295  ... 1.1002364  1.1024427  1.1036191 ]]


In [13]:
import pickle

np.save('en_fasttext_50000/nn_matrix.npy', I)
np.save('en_fasttext_50000/embeddings_matrix.npy', embedding_matrix)
pickle.dump(word2index, open('en_fasttext_50000/word2index.pcl', 'wb'))
pickle.dump(index2word, open('en_fasttext_50000/index2word.pcl', 'wb'))

In [14]:
?pickle.dump

In [21]:
np.linalg.norm(np.array([1, 1]) / np.linalg.norm([1, 1]))

0.9999999999999999