In [2]:
from tqdm import tqdm

### load wikititles

In [2]:
# with open("sapbert/training_data/general_domain_parallel_data/en_ru_muse+wikititle_pairs.txt", encoding="utf-8") as f:
#     wikititles = f.readlines()

In [3]:
# wikititles = [x.strip().split("||") for x in wikititles]

In [4]:
# wikititles[0]

['MRU0', 'категория', 'category']

In [5]:
# len(wikititles)

935474

### load sapbert

In [3]:
from transformers import AutoTokenizer, AutoModel
path = "sapbert/train/tmp/xlmr_base_sap_xling_tuned"
tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)  
model = AutoModel.from_pretrained(path, local_files_only=True).cuda(0)

In [4]:
import numpy as np
import torch

In [5]:
def encode(data, idx=1):
    bs = 128
    all_reps = []
    for i in tqdm(np.arange(0, len(data), bs)):
        batch = [x[idx] for x in data[i:i+bs]]
        toks = tokenizer.batch_encode_plus(
            batch, padding="max_length", max_length=25, truncation=True, return_tensors="pt"
        )
        toks_cuda = {}
        for k,v in toks.items():
            toks_cuda[k] = v.cuda(0)
        output = model(**toks_cuda)

    #     output = model(**toks)
        cls_rep = output[0][:,0,:]

        all_reps.append(cls_rep.cpu().detach().numpy())
    return np.concatenate(all_reps, axis=0)

### load wikidata

In [6]:
with open("wikidata-filtered-ru.tsv", encoding="utf-8") as f:
    wikidata = f.readlines()

In [7]:
wikidata = [x.strip().split("\t")[1:] for x in wikidata[1:]]

In [8]:
wikidata[0]

['Q5', 'Человек']

In [9]:
len(wikidata)

2930904

In [19]:
with open("query-5.csv", encoding="utf-8") as f:
    disambiguations = f.readlines()

disambiguations = [x.split("entity/")[1].strip() for x in disambiguations[1:]]
disambiguations = set(disambiguations)
len(disambiguations)

1361733

In [21]:
wikidata = [x for x in wikidata if x[0] not in disambiguations]
len(wikidata)

2741382

In [22]:
wikidata_emb = encode(wikidata)

100%|██████████| 21418/21418 [30:09<00:00, 11.84it/s]


In [23]:
print(wikidata_emb.shape)

(2741382, 768)


### tacred EL

In [24]:
import glob

tacred = []
for file in glob.glob("50text_tacred_EL/*.ann"):
    with open(file, encoding="utf-8") as f:
        data = f.readlines()

    entities = {}
    for line in data:
        if "Reference" in line and "Wikidata:" in line:
            tag, wd_id = line.split()[2:4]
            entities[tag] = wd_id.replace("Wikidata:", "").strip()

    for line in data:
        tag, *tmp = line.split("\t")
        if tag[0] == "T" and tag in entities and entities[tag] != "NULL":
            tacred.append((entities[tag], tmp[1].strip(), tmp[0].split()[0]))


tacred = list(set(tacred))

In [25]:
print(len(tacred))

1249


In [26]:
print(tacred[0], tacred[-1])

('Q193391', 'Дипломат', 'PROFESSION') ('Q102673', 'Газпром', 'ORGANIZATION')


In [28]:
tacred_emb = encode(tacred)
print(tacred_emb.shape)

100%|██████████| 10/10 [00:01<00:00,  7.75it/s]

(1249, 768)





#### encode nerel

In [10]:
# import glob
# data = []
# for file in glob.glob("NEREL/train/*.ann"):
#     with open(file, encoding="utf-8") as f:
#         tmp = f.readlines()
#         tmp = [file + "\t" + x.replace("\n", "") for x in tmp]
#         data.extend(tmp)
# nerel = [x.split("\t") for x in data if x.split("\t")[1][0] == "T"]

In [14]:
# len(nerel)

44627

In [15]:
# nerel[:5]

[['train/27439_text.ann', 'T2', 'PERSON 28 35', 'Норьеги'],
 ['train/27439_text.ann', 'T3', 'LOCATION 58 64', 'Боедик'],
 ['train/27439_text.ann', 'T4', 'LOCATION 66 72', 'Boëdic'],
 ['train/27439_text.ann', 'T5', 'COUNTRY 74 85', 'Французский'],
 ['train/27439_text.ann', 'T6', 'PERSON 94 108', 'Оливье Метцнер']]

In [16]:
# nerel_emb = encode(nerel, idx=3)
# print(nerel_emb.shape)

100%|██████████| 349/349 [00:28<00:00, 12.18it/s]

(44627, 768)





In [None]:
# query = "Доминик де Вильпен"
# query_toks = tokenizer.batch_encode_plus(
#     [query], padding="max_length", max_length=25, truncation=True, return_tensors="pt"
# )

In [None]:
# query_output = model(**query_toks)
# query_cls_rep = query_output[0][:,0,:]

In [None]:
# query_cls_rep.shape

### top1 accuracy

In [29]:
%%time
import faiss

dim = 768

index = faiss.index_factory(dim, "IVF1000,Flat")
index.train(wikidata_emb)
index.add(wikidata_emb)
print(index.ntotal)

2741382
CPU times: user 2min 17s, sys: 20.7 s, total: 2min 38s
Wall time: 2min 25s


In [30]:
%%time
topn = 1
index.nprobe = 16  # Проходим по топ-16 центроид для поиска top-n ближайших соседей
D, I = index.search(tacred_emb, topn)
I.shape

CPU times: user 40.4 s, sys: 45.8 ms, total: 40.5 s
Wall time: 10.3 s


(1249, 1)

In [31]:
result = [
    str(tacred[i]) + "\tPred: " + str((wikidata[I[i][0]][0], wikidata[I[i][0]][1]))
    for i in range(len(I))
]

with open("out.txt", "w", encoding="utf-8") as f:
    for i in range(len(I)):
        f.write(result[i] + "\n")

In [32]:
result[1000]

"('Q2915012', 'Тель ха-Шомер', 'FACILITY')\tPred: ('Q2915012', 'Тель ха-Шомер')"

In [33]:
y_true = [x[0] for x in tacred]
y_pred = [wikidata[i[0]][0] for i in I]
acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print("Top1 accuracy: ", acc)

Top1 accuracy:  0.5036028823058447


### top3 & top5 accuracy

In [34]:
topn = 3
index.nprobe = 16  # Проходим по топ-16 центроид для поиска top-n ближайших соседей
D, I = index.search(tacred_emb, topn)
y_true = [x[0] for x in tacred]
y_pred = [[wikidata[i][0] for i in ii] for ii in I]
acc = sum(t in p for t, p in zip(y_true, y_pred)) / len(y_true)
print("Top3 accuracy: ", acc)

Top3 accuracy:  0.5876701361088871


In [35]:
topn = 5
index.nprobe = 16  # Проходим по топ-16 центроид для поиска top-n ближайших соседей
D, I = index.search(tacred_emb, topn)
I.shape
y_true = [x[0] for x in tacred]
y_pred = [[wikidata[i][0] for i in ii] for ii in I]
acc = sum(t in p for t, p in zip(y_true, y_pred)) / len(y_true)
print("Top5 accuracy: ", acc)

Top5 accuracy:  0.6156925540432346


#### accuracy by groups

In [36]:
groups = set(x[2] for x in tacred)
print(groups)

{'STATE_OR_PROVINCE', 'WORK_OF_ART', 'LAW', 'NATIONALITY', 'ORGANIZATION', 'RELIGION', 'LOCATION', 'FACILITY', 'LANGUAGE', 'AWARD', 'CITY', 'DISTRICT', 'COUNTRY', 'PRODUCT', 'PERSON', 'PROFESSION'}


In [37]:
accuracy = {}
for g in groups:
    ii = [i for i, x in enumerate(tacred) if x[2] == g]
    y_true = [tacred[i][0] for i in ii]
    y_pred = [wikidata[I[i][0]][0] for i in ii]
    accuracy[g] = sum(t in p for t, p in zip(y_true, y_pred)) / len(y_true)
    print(f"{g}:\t{accuracy[g]:.3f}")

STATE_OR_PROVINCE:	0.429
WORK_OF_ART:	0.455
LAW:	1.000
NATIONALITY:	0.250
ORGANIZATION:	0.514
RELIGION:	0.667
LOCATION:	0.312
FACILITY:	0.524
LANGUAGE:	0.250
AWARD:	0.583
CITY:	0.282
DISTRICT:	0.500
COUNTRY:	0.301
PRODUCT:	0.667
PERSON:	0.570
PROFESSION:	0.274


In [38]:
with open("out.txt", "r", encoding="utf-8") as f:
    a = f.readlines()

In [39]:
a[15]

"('Q27832358', 'Верховный комиссар ООН по делам беженцев', 'PROFESSION')\tPred: ('Q132551', 'Верховный комиссариат ООН по делам беженцев')\n"