In [1]:
import glob
import os
import pickle
from datetime import datetime

import faiss
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

In [2]:
# Print the training script, presumably the one used to fine-tune the SapBERT checkpoints below.
# NOTE(review): the recorded output shows this path does not exist relative to the notebook's
# working directory — fix the path or drop this cell.
!cat sapbert/train/generalized_wikidata_train.sh

cat: sapbert/train/generalized_wikidata_train.sh: No such file or directory


### Load the SapBERT encoder

In [3]:
# Encoder to evaluate: experiment name -> (model path or hub id, load from local files only).
# A single cell loads exactly one encoder, so `experiment`, `tokenizer` and `model` can no longer
# depend on which of several alternative cells happened to run last.
EXPERIMENTS = {
    "sapbert_xlm_roberta_base": ("xlm-roberta-base", False),
    "sapbert_wikidata_collated": ("experiments/sapbert_wikidata_collated/model", True),
    "sapbert_wikititles": ("experiments/sapbert_wikititles/model", True),
}
experiment = "sapbert_wikititles"

path, local_files_only = EXPERIMENTS[experiment]
tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=local_files_only)
model = AutoModel.from_pretrained(path, local_files_only=local_files_only).cuda(0)

In [4]:
def encode(data, idx=1, batch_size=128, max_length=25, tok=None, encoder=None):
    """Embed one text field of every record with the encoder's first-token hidden state.

    Args:
        data: sequence of indexable records; ``record[idx]`` is the text to embed
            (field 1 of the ``[wikidata_id, label]`` rows and ``(wikidata_id, mention, type)``
            tuples used in this notebook).
        idx: position of the text field inside each record.
        batch_size: number of texts tokenized and encoded per forward pass.
        max_length: every text is padded/truncated to exactly this many tokens.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        encoder: model to use; defaults to the notebook-level ``model``. Inputs are moved to
            the device of its first parameter.

    Returns:
        numpy array with one row per record (row ``i`` embeds ``data[i][idx]``); an empty
        ``(0, hidden_size)`` array when ``data`` is empty.
    """
    tok = tokenizer if tok is None else tok
    encoder = model if encoder is None else encoder
    device = next(encoder.parameters()).device

    all_reps = []
    # Inference only: no_grad skips recording the autograd graph, saving memory and time.
    with torch.no_grad():
        for start in tqdm(range(0, len(data), batch_size)):
            batch = [record[idx] for record in data[start:start + batch_size]]
            toks = tok.batch_encode_plus(
                batch, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt"
            )
            output = encoder(**{k: v.to(device) for k, v in toks.items()})
            # output[0] holds the per-token hidden states (last_hidden_state for HF AutoModel);
            # position 0 is the first token of every sequence.
            all_reps.append(output[0][:, 0, :].cpu().numpy())

    if not all_reps:
        # np.concatenate([]) raises; return a well-shaped empty matrix instead.
        hidden = getattr(getattr(encoder, "config", None), "hidden_size", 0)
        return np.empty((0, hidden), dtype=np.float32)
    return np.concatenate(all_reps, axis=0)

### Load the Wikidata dictionary

In [5]:
# Ids of Wikidata disambiguation pages; rows with these ids are dropped from the dictionary below.
# The CSV starts with a header row; every later row contains a ".../entity/<id>" reference.
with open("data/wikidata_disumbiguation.csv", encoding="utf-8") as csv_file:
    next(csv_file, None)  # skip the header row
    disambiguations = {row.split("entity/")[1].strip() for row in csv_file}

In [6]:
# Dictionary to link against: "wikidata" (data/wikidata.txt) or "wikidata-filtered"
# (data/wikidata-filtered-ru.tsv). Kept as one switch instead of two alternative cells.
wikidata_type = "wikidata"


def load_wikidata(kind):
    """Load ``[wikidata_id, label, ...]`` rows for ``kind``, dropping disambiguation pages.

    Rows without at least an id and a label are skipped, because ``encode`` reads field 1
    of every row. Raises ValueError for an unknown ``kind``.
    """
    if kind == "wikidata-filtered":
        path, sep, fields, has_header = "data/wikidata-filtered-ru.tsv", "\t", slice(1, None), True
    elif kind == "wikidata":
        path, sep, fields, has_header = "data/wikidata.txt", "||", slice(None, 2), False
    else:
        raise ValueError(f"unknown wikidata_type: {kind!r}")

    with open(path, encoding="utf-8") as f:
        if has_header:
            next(f, None)
        # Stream the (multi-million line) file instead of materialising readlines() first.
        rows = (line.strip().split(sep)[fields] for line in f)
        return [row for row in rows if len(row) >= 2 and row[0] not in disambiguations]


wikidata = load_wikidata(wikidata_type)

In [7]:
# Encoding the whole dictionary takes about an hour on one GPU, so the embeddings are cached
# per (experiment, wikidata_type). Delete the file to re-encode after retraining a model or
# changing the dictionary filter.
wikidata_emb_path = f"experiments/{experiment}/results/{wikidata_type}.npy"
if os.path.exists(wikidata_emb_path):
    wikidata_emb = np.load(wikidata_emb_path)
else:
    wikidata_emb = encode(wikidata)
    os.makedirs(os.path.dirname(wikidata_emb_path), exist_ok=True)
    np.save(wikidata_emb_path, wikidata_emb)

100%|██████████| 41674/41674 [57:18<00:00, 12.12it/s] 


In [8]:
# Which dictionary the embeddings and index in this run were built from.
print(f"{wikidata_type}")

wikidata


In [9]:
# with open(f"experiments/{experiment}/results/{wikidata_type}.pickle", "wb") as f:
#     pickle.dump(wikidata_emb, f, protocol=pickle.HIGHEST_PROTOCOL)

In [10]:
# with open(f"experiments/{experiment}/results/{wikidata_type}.pickle", "rb") as f:
#     wikidata_emb = pickle.load(f)

### TACRED 2411: mentions with Wikidata links

In [9]:
def parse_tacred_ann(lines):
    """Extract ``(wikidata_id, mention_text, entity_type)`` from the lines of one brat ``.ann`` file.

    A text-bound annotation (tag starting with "T") is kept only when some line with
    "Reference" and "Wikidata:" links its tag to an id. The id is returned with the
    "Wikidata:" prefix removed, so unlinked references yield e.g. "NULL". Output follows
    the order of the T-lines in ``lines``.
    """
    entities = {}
    for line in lines:
        if "Reference" in line and "Wikidata:" in line:
            tag, wd_id = line.split()[2:4]
            entities[tag] = wd_id.replace("Wikidata:", "").strip()

    mentions = []
    for line in lines:
        tag, *rest = line.split("\t")
        # startswith() instead of tag[0]: an empty tag must not raise IndexError.
        if tag.startswith("T") and tag in entities:
            mentions.append((entities[tag], rest[1].strip(), rest[0].split()[0]))
    return mentions


tacred = []
# sorted(): glob order is filesystem-dependent, and the row order of `tacred` is
# written to the prediction dump below.
for ann_path in sorted(glob.glob("data/tacred_2411/*/*.ann")):
    with open(ann_path, encoding="utf-8") as f:
        tacred.extend(parse_tacred_ann(f.readlines()))

In [10]:
# Number of mentions that carry a Wikidata reference (the gold id may be "NULL").
len(tacred)

29962

In [11]:
# Fraction of mentions whose gold link is "NULL" (annotated as Wikidata:NULL).
sum(1 for gold_id, *_ in tacred if gold_id == "NULL") / len(tacred)

0.15886789933916293

In [12]:
# Field 1 of each (wikidata_id, mention, type) tuple is the mention text.
tacred_emb = encode(data=tacred, idx=1)

100%|██████████| 235/235 [00:19<00:00, 12.16it/s]


#### Encode NEREL (disabled — the cells below are commented out)

In [13]:
# import glob
# data = []
# for file in glob.glob("NEREL/train/*.ann"):
#     with open(file, encoding="utf-8") as f:
#         tmp = f.readlines()
#         tmp = [file + "\t" + x.replace("\n", "") for x in tmp]
#         data.extend(tmp)
# nerel = [x.split("\t") for x in data if x.split("\t")[1][0] == "T"]

In [14]:
# len(nerel)

In [15]:
# nerel[:5]

In [16]:
# nerel_emb = encode(nerel, idx=3)
# print(nerel_emb.shape)

In [17]:
# query = "Доминик де Вильпен"
# query_toks = tokenizer.batch_encode_plus(
#     [query], padding="max_length", max_length=25, truncation=True, return_tensors="pt"
# )

In [18]:
# query_output = model(**query_toks)
# query_cls_rep = query_output[0][:,0,:]

In [19]:
# query_cls_rep.shape

### FAISS nearest-neighbour search (top-5 candidates per mention)

In [20]:
# Row j of the index must describe wikidata[j]; otherwise every I -> wikidata lookup below is
# silently wrong (e.g. a stale embedding cache from a differently filtered dictionary).
assert wikidata_emb.shape[0] == len(wikidata), "wikidata_emb is not aligned with wikidata"
wikidata_emb.shape

(5334165, 768)

In [21]:
# One query per mention, with the same width as the indexed dictionary vectors.
assert tacred_emb.shape == (len(tacred), wikidata_emb.shape[1]), "tacred_emb shape mismatch"
tacred_emb.shape

(29962, 768)

In [22]:
%%time
print(datetime.now())
dim = 768

index = faiss.index_factory(dim, "IVF65536_HNSW32,Flat")

res = faiss.StandardGpuResources()
index_ivf = faiss.extract_index_ivf(index)
index_flat = faiss.IndexFlatL2(768)
clustering_index = faiss.index_cpu_to_gpu(res, 0, index_flat)  #  0 – номер GPU
index_ivf.clustering_index = clustering_index

index.train(wikidata_emb)
index.add(wikidata_emb)
print(datetime.now())

2021-12-02 15:14:15.594703
2021-12-02 15:31:11.942504
CPU times: user 1h 7min 10s, sys: 16.3 s, total: 1h 7min 26s
Wall time: 16min 56s


In [23]:
%%time
print(datetime.now())
topn = 5
index.nprobe = 16  # Проходим по топ-16 центроид для поиска top-n ближайших соседей
D, I = index.search(tacred_emb, topn)
I.shape
print(datetime.now())

2021-12-02 15:31:12.206713
2021-12-02 15:31:20.086462
CPU times: user 52.8 s, sys: 360 ms, total: 53.2 s
Wall time: 7.88 s


In [24]:
# One line per mention: the gold (wikidata_id, mention, type) tuple and the top-1 (id, label) prediction.
pred_lines = [
    f"{mention}\tPred: {tuple(wikidata[nn_ids[0]][:2])}"
    for mention, nn_ids in zip(tacred, I)
]

out_path = f"experiments/{experiment}/results/out-{wikidata_type}.txt"
# The results directory does not exist yet for a fresh experiment (e.g. the hub base model).
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
    f.writelines(line + "\n" for line in pred_lines)

### accuracy

In [25]:
# Distance cut-off for predicting "NULL": accuracy(..., null_thresh=THRESH) replaces every
# retrieved neighbour whose FAISS distance is not below this value with "NULL".
# NOTE(review): looks hand-tuned for this encoder/index (presumably squared L2 distances, FAISS's
# default metric) — record how it was chosen and re-tune it when the encoder changes.
THRESH = 220

In [26]:
def accuracy(index, null_thresh=None):
    """Top-1/3/5 linking accuracy of the FAISS search results on a subset of TACRED.

    Reads the notebook globals ``tacred`` (gold ``(wikidata_id, mention, type)`` tuples),
    ``wikidata`` (rows starting with the wikidata id), and the ``index.search`` outputs
    ``I`` (neighbour row ids) and ``D`` (their distances).

    Args:
        index: iterable of row positions into ``tacred`` / ``I`` / ``D`` to evaluate.
        null_thresh: if given, every neighbour whose distance is not below this value is
            replaced by the prediction "NULL", so mentions with gold id "NULL" can be scored.

    Returns:
        ``(top1, top3, top5, null)``: three accuracies in [0, 1] and the number of mentions
        whose top-1 prediction is "NULL". The accuracies are NaN when ``index`` is empty.
    """
    rows = list(index)
    if not rows:
        nan = float("nan")
        return nan, nan, nan, 0

    def neighbour_id(pred, dist):
        if null_thresh is not None and not dist < null_thresh:
            return "NULL"
        # FAISS marks missing neighbours with id -1; wikidata[-1] would silently
        # predict the last dictionary entry, so such slots match no gold id.
        if pred < 0:
            return None
        return wikidata[pred][0]

    y_true = [tacred[i][0] for i in rows]
    y_pred = [[neighbour_id(p, d) for p, d in zip(I[i], D[i])] for i in rows]

    def top_k(k):
        return sum(t in p[:k] for t, p in zip(y_true, y_pred)) / len(y_true)

    null = sum("NULL" in p[:1] for p in y_pred)
    return top_k(1), top_k(3), top_k(5), null

In [27]:
# Entity types reported one row each (a fixed list keeps the table layout stable across runs).
groups = {'AWARD', 'CITY', 'COUNTRY', 'DISTRICT', 'FACILITY', 'LANGUAGE', 'LAW', 'LOCATION', 'NATIONALITY',
          'ORGANIZATION', 'PERSON', 'PRODUCT', 'PROFESSION', 'RELIGION', 'STATE_OR_PROVINCE', 'WORK_OF_ART'}


def _format_row(t1, t1n, t3, t3n, t5, t5n, null="", n="", name=""):
    """One tab-separated report row: accuracies without/with the NULL threshold, then counts and label."""
    return f"{t1:.3f}\t{t1n:.3f}\t{t3:.3f}\t{t3n:.3f}\t{t5:.3f}\t{t5n:.3f}\t{null}\t{n}\t{name}"


def _metrics(rows):
    """Return ``((t1, t1n, t3, t3n, t5, t5n), null_count)`` for the given tacred row positions."""
    t1n, t3n, t5n, null = accuracy(rows, null_thresh=THRESH)
    t1, t3, t5, _ = accuracy(rows)
    return (t1, t1n, t3, t3n, t5, t5n), null


report_lines = [
    f"Tacred Accuracy on {wikidata_type}",
    f"NULL Threshold {THRESH}",
    "Top1\tT1+Null\tTop3\tT3+Null\tTop5\tT5+Null\tNULL\tN\tGroup",
]
group_res = []
for g in sorted(groups):
    group_idx = [i for i, x in enumerate(tacred) if x[2] == g]
    metrics, null = _metrics(group_idx)
    group_res.append(metrics)
    report_lines.append(_format_row(*metrics, null, len(group_idx), g))

# Micro: every TACRED mention, including any type not listed in `groups`.
report_lines.append("")
metrics, null = _metrics(range(len(tacred)))
report_lines.append(_format_row(*metrics, null, len(tacred), "Total (Micro)"))

# Macro: unweighted mean of the per-group accuracies (NULL count and N left blank).
macro = [sum(column) / len(group_res) for column in zip(*group_res)]
report_lines.append(_format_row(*macro, name="Total (Macro)"))

result = "\n".join(report_lines)
print(result)

accuracy_path = f"experiments/{experiment}/results/accuracy-{wikidata_type}.txt"
os.makedirs(os.path.dirname(accuracy_path), exist_ok=True)
with open(accuracy_path, "w", encoding="utf-8") as f:
    f.write(result)

Tacred Accuracy on wikidata
NULL Treshold 220
Top1	T1+Null	Top3	T3+Null	Top5	T5+Null	NULL	N	Group
0.209	0.277	0.346	0.482	0.355	0.529	59	512	AWARD
0.033	0.031	0.076	0.071	0.107	0.100	143	1687	CITY
0.030	0.028	0.112	0.108	0.142	0.133	140	3602	COUNTRY
0.251	0.263	0.299	0.317	0.335	0.353	19	167	DISTRICT
0.144	0.207	0.240	0.334	0.254	0.374	130	575	FACILITY
0.347	0.347	0.551	0.551	0.551	0.551	0	49	LANGUAGE
0.135	0.381	0.156	0.493	0.164	0.548	221	653	LAW
0.066	0.101	0.125	0.184	0.134	0.198	46	424	LOCATION
0.250	0.252	0.384	0.386	0.457	0.459	17	560	NATIONALITY
0.270	0.310	0.392	0.459	0.418	0.493	776	5851	ORGANIZATION
0.197	0.244	0.238	0.296	0.249	0.319	2303	7450	PERSON
0.494	0.501	0.569	0.586	0.571	0.607	29	399	PRODUCT
0.279	0.341	0.445	0.550	0.500	0.630	696	6883	PROFESSION
0.420	0.420	0.629	0.643	0.706	0.720	4	143	RELIGION
0.092	0.094	0.180	0.180	0.221	0.217	26	566	STATE_OR_PROVINCE
0.239	0.303	0.315	0.408	0.348	0.450	69	422	WORK_OF_ART

0.202	0.245	0.299	0.364	0.329	0.406	4682	29962	Total (