In [1]:
from tqdm import tqdm

### load wikititles

In [2]:
# with open("sapbert/training_data/general_domain_parallel_data/en_ru_muse+wikititle_pairs.txt", encoding="utf-8") as f:
#     wikititles = f.readlines()

In [3]:
# wikititles = [x.strip().split("||") for x in wikititles]

In [4]:
# wikititles[0]

['MRU0', 'категория', 'category']

In [5]:
# len(wikititles)

935474

### load wikidata

In [12]:
with open("wikidata-filtered-ru.tsv", encoding="utf-8") as f:
    wikidata = f.readlines()

In [13]:
wikidata = [x.strip().split("\t")[1:] for x in wikidata[1:]]

In [17]:
wikidata[0]

['Q5', 'Человек']

In [18]:
len(wikidata)

2930904

### load sapbert

In [19]:
from transformers import AutoTokenizer, AutoModel
path = "sapbert/train/tmp/xlmr_base_sap_xling_tuned"
tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)  
model = AutoModel.from_pretrained(path, local_files_only=True).cuda(0)

In [20]:
import numpy as np
import torch

In [38]:
def encode(data, idx=1):
    bs = 128
    all_reps = []
    for i in tqdm(np.arange(0, len(data), bs)):
        batch = [x[idx] for x in data[i:i+bs]]
        toks = tokenizer.batch_encode_plus(
            batch, padding="max_length", max_length=25, truncation=True, return_tensors="pt"
        )
        toks_cuda = {}
        for k,v in toks.items():
            toks_cuda[k] = v.cuda(0)
        output = model(**toks_cuda)

    #     output = model(**toks)
        cls_rep = output[0][:,0,:]

        all_reps.append(cls_rep.cpu().detach().numpy())
    return np.concatenate(all_reps, axis=0)

#### encode wikidata labels

In [22]:
wikidata_emb = encode(wikidata)

100%|██████████| 22898/22898 [32:14<00:00, 11.84it/s]


In [23]:
print(wikidata_emb.shape)

(2930904, 768)


### tacred EL

In [34]:
import glob

tacred = []
for file in glob.glob("50text_tacred_EL/*.ann"):
    with open(file, encoding="utf-8") as f:
        data = f.readlines()

    entities = {}
    for line in data:
        if "Reference" in line and "Wikidata:" in line:
            tag, wd_id = line.split()[2:4]
            entities[tag] = wd_id.replace("Wikidata:", "")

    for line in data:
        tag, *tmp = line.split("\t")
        if tag[0] == "T" and tag in entities:
            tacred.append((entities[tag], tmp[1].strip()))


tacred = list(set(tacred))

In [35]:
print(len(tacred))

1420


In [36]:
print(tacred[0], tacred[-1])

('Q145', 'Великобритании') ('Q494412', 'Ли Гонхи')


In [39]:
tacred_emb = encode(tacred)
print(tacred_emb.shape)

100%|██████████| 12/12 [00:01<00:00,  8.29it/s]

(1420, 768)





#### encode nerel

In [10]:
# import glob
# data = []
# for file in glob.glob("NEREL/train/*.ann"):
#     with open(file, encoding="utf-8") as f:
#         tmp = f.readlines()
#         tmp = [file + "\t" + x.replace("\n", "") for x in tmp]
#         data.extend(tmp)
# nerel = [x.split("\t") for x in data if x.split("\t")[1][0] == "T"]

In [14]:
# len(nerel)

44627

In [15]:
# nerel[:5]

[['train/27439_text.ann', 'T2', 'PERSON 28 35', 'Норьеги'],
 ['train/27439_text.ann', 'T3', 'LOCATION 58 64', 'Боедик'],
 ['train/27439_text.ann', 'T4', 'LOCATION 66 72', 'Boëdic'],
 ['train/27439_text.ann', 'T5', 'COUNTRY 74 85', 'Французский'],
 ['train/27439_text.ann', 'T6', 'PERSON 94 108', 'Оливье Метцнер']]

In [16]:
# nerel_emb = encode(nerel, idx=3)
# print(nerel_emb.shape)

100%|██████████| 349/349 [00:28<00:00, 12.18it/s]

(44627, 768)





In [None]:
# query = "Доминик де Вильпен"
# query_toks = tokenizer.batch_encode_plus(
#     [query], padding="max_length", max_length=25, truncation=True, return_tensors="pt"
# )

In [None]:
# query_output = model(**query_toks)
# query_cls_rep = query_output[0][:,0,:]

In [None]:
# query_cls_rep.shape

#### find query's nearest neighbour

In [42]:
%%time
# for large-scale search, should switch to faiss
# from scipy.spatial.distance import cdist
import faiss

dim = 768

index = faiss.index_factory(dim, "IVF1000,Flat")
index.train(wikidata_emb)
index.add(wikidata_emb)
print(index.ntotal)

2930904
CPU times: user 2min 22s, sys: 26.8 s, total: 2min 49s
Wall time: 2min 57s


In [44]:
%%time
topn = 1
index.nprobe = 16  # Проходим по топ-16 центроид для поиска top-n ближайших соседей
D, I = index.search(tacred_emb, topn)
I.shape

CPU times: user 49.4 s, sys: 20.2 ms, total: 49.4 s
Wall time: 12.4 s


(1420, 1)

In [None]:
# dist = cdist(query_cls_rep.cpu().detach().numpy(), all_reps_emb)
# nn_index = np.argmin(dist)
# print ("predicted label:", nerel[nn_index])

In [46]:
I[0]

array([70])

In [61]:
result = [str(tacred[i]) + "\tPred: " + str((wikidata[I[i][0]][0], wikidata[I[i][0]][1]))
         for i in range(len(I))]

with open("out.txt", "w", encoding="utf-8") as f:
    for i in range(len(I)):
        f.write(result[i] + "\n")

In [63]:
result[1000]

"('Q1501926', 'генпрокурор')\tPred: ('Q449319', 'Генеральный прокурор Германии')"

In [55]:
y_true = [x[0] for x in tacred]
y_pred = [wikidata[i[0]][0] for i in I]

In [60]:
y_true[1000], y_pred[1000]

('Q1501926', 'Q449319')

In [59]:
sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

0.4295774647887324