In [1]:
import os 
import random
import pickle
from collections import defaultdict

import ujson 
import numpy as np
import copy
import pandas as pd
from tqdm import tqdm

np.random.seed(4680)
random.seed(4680)

SIM_RELATION = "is_similar_to"
COMPL_RELATION = "is_complementary_to"
REL_RELATION = "is_relevant_to"

in_dir = "/home/jupyter/jointly_rec_and_search/datasets/kgc/"

aid_to_sim_pids = {}
aid_to_rel_qids = {}
cid_to_sim_pids = {}
cid_to_rel_qids = {}
cid_to_compl_pids = {}
with open(os.path.join(in_dir, "one_hop_relation.pkl"), "rb") as fin:
    obj_dict = pickle.load(fin)
    aid_to_sim_pids = obj_dict["aid_sim"]
    aid_to_rel_qids = obj_dict["aid_rel"]
    cid_to_sim_pids = obj_dict["compl_sim"]
    cid_to_rel_qids = obj_dict["compl_rel"]
    cid_to_compl_pids = obj_dict["compl_compl"]
    
aid_to_compl_pids = defaultdict(list)
with open(os.path.join(in_dir, "arels.compl.train.tsv")) as fin:
    for line in fin:
        aid, _, pid, _ = line.rstrip().split("\t")
        aid, pid = int(aid), int(pid)
        aid_to_compl_pids[aid].append(pid)
    
print("number of aid_to_sim_pid = {:,}, aid_to_rel_qid = {:,}, cid_to_sim_pid = {:,}, cid_to_rel_qid = {:,}, cid_to_compl_pid = {:,}".format(
    sum([len(xs) for xs in aid_to_sim_pids.values()]), sum([len(xs) for xs in aid_to_rel_qids.values()]), 
    sum([len(xs) for xs in cid_to_sim_pids.values()]),
    sum([len(xs) for xs in cid_to_rel_qids.values()]), sum([len(xs) for xs in cid_to_compl_pids.values()])
))
print("number of aid_to_compl_pid = {:,}".format(sum([len(xs) for xs in aid_to_compl_pids.values()])))

number of aid_to_sim_pid = 519,522, aid_to_rel_qid = 2,676,047, cid_to_sim_pid = 377,449, cid_to_rel_qid = 2,146,221, cid_to_compl_pid = 195,865
number of aid_to_compl_pid = 264,134


In [2]:
# all dual negs
out_dir = "/home/jupyter/jointly_rec_and_search/datasets/kgc/train/dual_neg/"
all_miss = 0
max5_aid_to_compl_pids = {}
max5_aid_to_sim_pids = {}
max5_aid_to_rel_qids = {}
max5_cid_to_compl_pids = {}
max5_cid_to_sim_pids = {}
max5_cid_to_rel_qids = {}

for aid in aid_to_compl_pids:
    max5_aid_to_compl_pids[aid] = random.sample(aid_to_compl_pids[aid], len(aid_to_compl_pids[aid]))[:5]
    
    # aid -> xxx 
    if aid in aid_to_sim_pids:
        max5_aid_to_sim_pids[aid] = random.sample(aid_to_sim_pids[aid], len(aid_to_sim_pids[aid]))[:5]
    if aid in aid_to_rel_qids:
        max5_aid_to_rel_qids[aid] = random.sample(aid_to_rel_qids[aid], len(aid_to_rel_qids[aid]))[:5]
    
    # cid -> xxx 
    compl_pids = max5_aid_to_compl_pids[aid]
    for cid in compl_pids:
        if cid in cid_to_sim_pids:
            max5_cid_to_sim_pids[cid] = random.sample(cid_to_sim_pids[cid], len(cid_to_sim_pids[cid]))[:5]
        if cid in cid_to_rel_qids:
            max5_cid_to_rel_qids[cid] = random.sample(cid_to_rel_qids[cid], len(cid_to_rel_qids[cid]))[:5]
        if cid in cid_to_compl_pids:
            max5_cid_to_compl_pids[cid] = random.sample(cid_to_compl_pids[cid], len(cid_to_compl_pids[cid]))[:5]
            
        if all([cid not in cid_to_sim_pids, cid not in cid_to_rel_qids, cid not in cid_to_compl_pids]):
            all_miss += 1 
            
print("# line of max5_aid_to_compl_pids = {:,}, max5_aid_to_sim_pids = {:,}, max5_aid_to_rel_qids = {:,}".format(
        len(max5_aid_to_compl_pids), len(max5_aid_to_sim_pids), len(max5_aid_to_rel_qids)))
print("# line of max5_cid_to_compl_pids = {:,}, max5_cid_to_sim_pids = {:,}, max5_cid_to_rel_qids = {:,}".format(
        len(max5_cid_to_compl_pids), len(max5_cid_to_sim_pids), len(max5_cid_to_rel_qids)))
print("-"*75)
print("# of max5_aid_to_compl_pids = {:,}, max5_aid_to_sim_pids = {:,}, max5_aid_to_rel_qids = {:,}".format(
        sum([len(x) for x in max5_aid_to_compl_pids.values()]), sum([len(x) for x in max5_aid_to_sim_pids.values()]),
        sum([len(x) for x in max5_aid_to_rel_qids.values()])))
print("# of of max5_cid_to_compl_pids = {:,}, max5_cid_to_sim_pids = {:,}, max5_cid_to_rel_qids = {:,}".format(
        sum([len(x) for x in max5_cid_to_compl_pids.values()]), sum([len(x) for x in max5_cid_to_sim_pids.values()]),
        sum([len(x) for x in max5_cid_to_rel_qids.values()])))
print("-"*75)
print("all miss = {:,}, miss_rate = {:.3f}".format(all_miss, all_miss/sum([len(x) for x in max5_aid_to_compl_pids.values()])))

# bm25 run
run_path = "/home/jupyter/jointly_rec_and_search/datasets/kgc/runs/bm25.train.run"
df = pd.read_csv(run_path, sep=" ", names=["hid", "q0", "tid", "rank", "score", "model_name"])
bm25_hid_to_tids = {}
ignore_hids = set()
for hid, group in df.groupby("hid"):
    cand_tids = list(group.tid.values)
    if len(cand_tids) < 10:
        ignore_hids.add(int(hid))
    else:
        bm25_hid_to_tids[int(hid)] = [int(x) for x in cand_tids]
        
print("number of ignore hids = {}".format(len(ignore_hids)))


# line of max5_aid_to_compl_pids = 69,496, max5_aid_to_sim_pids = 55,192, max5_aid_to_rel_qids = 62,643
# line of max5_cid_to_compl_pids = 29,305, max5_cid_to_sim_pids = 42,380, max5_cid_to_rel_qids = 48,866
---------------------------------------------------------------------------
# of max5_aid_to_compl_pids = 178,162, max5_aid_to_sim_pids = 218,676, max5_aid_to_rel_qids = 279,656
# of of max5_cid_to_compl_pids = 101,230, max5_cid_to_sim_pids = 166,981, max5_cid_to_rel_qids = 217,139
---------------------------------------------------------------------------
all miss = 5,306, miss_rate = 0.030
number of ignore hids = 4842


In [9]:
eid_to_text = {}
with open("/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv") as fin:
    for line in fin:
        eid, text = line.strip().split("\t")
        eid_to_text[int(eid)] = text

if not os.path.exists(out_dir):
    os.mkdir(out_dir)
# start to create triples
miss_triples = 0
miss_qids = set()
duplicate_pairs = []
with open(os.path.join(out_dir, "max5_triples.a2cp.tsv"), "w") as fout:
    for aid, compl_pids in max5_aid_to_compl_pids.items():
        assert aid in bm25_hid_to_tids
        
        neg_pids = []
        sim_pids = list(set(aid_to_sim_pids[aid]) - set(compl_pids))
        neg_sim_pids = random.sample(sim_pids, k=len(sim_pids))[:5]
        assert len(set(neg_sim_pids) & set(compl_pids)) == 0 
        neg_pids += neg_sim_pids
        while len(neg_pids) <= 5:          
            neg_pid = random.sample(bm25_hid_to_tids[aid], k=1)[0]
            while neg_pid in compl_pids:
                neg_pid = random.sample(bm25_hid_to_tids[aid], k=1)[0]
            neg_pids.append(neg_pid)          
        neg_pids = neg_pids[: len(compl_pids)]
        
        for pos_pid, neg_pid in zip(compl_pids, neg_pids):
            fout.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{COMPL_RELATION}\n")
                  
print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.a2sp.rnd.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.a2sp.rnd.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.a2sp.rnd.tsv"), "w") as fout3:
            for aid, sim_pids in max5_aid_to_sim_pids.items():
                assert aid in bm25_hid_to_tids
                    
                true_sim_pids = []
                for pos_pid in sim_pids:
                    if eid_to_text[aid] == eid_to_text[pos_pid]:
                        duplicate_pairs.append((aid, pos_pid))
                    else:
                        true_sim_pids.append(pos_pid)
                        
                neg_pids = []
                compl_pids = list(set(aid_to_compl_pids[aid]) - set(true_sim_pids))
                neg_compl_pids = random.sample(compl_pids, k=len(compl_pids))[:5]
                assert len(set(neg_compl_pids) & set(true_sim_pids)) == 0
                neg_pids += neg_compl_pids
                while len(neg_pids) <= 5:
                    neg_pid = random.sample(range(2_000_000), k=1)[0]
                    while neg_pid in sim_pids:
                        neg_pid = random.sample(range(2_000_000), k=1)[0]
                    neg_pids.append(neg_pid)
                neg_pids = neg_pids[:len(true_sim_pids)]
                
                cand_triples = []
                for pos_pid, neg_pid in zip(true_sim_pids, neg_pids):
                    cand_triples.append((aid, pos_pid, neg_pid))
                    
                for aid, pos_pid, neg_pid in cand_triples[:5]:
                    fout.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")

print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.q2a.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.q2a.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.q2a.tsv"), "w") as fout3:
            for aid, qids in max5_aid_to_rel_qids.items():
                cand_triples = []
                for qid in qids[:5]:
                    if qid not in bm25_hid_to_tids:
                        miss_triples += 1
                        miss_qids.add(qid)
                        continue
                    neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    while neg_pid == aid:
                        neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    cand_triples.append((qid, aid, neg_pid))
                for qid, aid, neg_pid in cand_triples[:5]:
                    fout.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, aid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, aid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")
                    
print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.c2sp.rnd.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.c2sp.rnd.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.c2sp.rnd.tsv"), "w") as fout3:
            for aid, sim_pids in max5_cid_to_sim_pids.items():
                assert aid in bm25_hid_to_tids
                    
                true_sim_pids = []
                for pos_pid in sim_pids:
                    if eid_to_text[aid] == eid_to_text[pos_pid]:
                        duplicate_pairs.append((aid, pos_pid))
                    else:
                        true_sim_pids.append(pos_pid)
                        
                neg_pids = []
                compl_pids = list(set(cid_to_compl_pids[aid]) - set(true_sim_pids))
                neg_compl_pids = random.sample(compl_pids, k=len(compl_pids))[:5]
                assert len(set(neg_compl_pids) & set(true_sim_pids)) == 0
                neg_pids += neg_compl_pids
                while len(neg_pids) <= 5:
                    neg_pid = random.sample(range(2_000_000), k=1)[0]
                    while neg_pid in sim_pids:
                        neg_pid = random.sample(range(2_000_000), k=1)[0]
                    neg_pids.append(neg_pid)
                neg_pids = neg_pids[:len(true_sim_pids)]
                
                cand_triples = []
                for pos_pid, neg_pid in zip(true_sim_pids, neg_pids):
                    cand_triples.append((aid, pos_pid, neg_pid))
                
                for aid, pos_pid, neg_pid in cand_triples[:5]:
                    fout.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n") 

print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.q2c.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.q2c.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.q2c.tsv"), "w") as fout3:
            for aid, qids in max5_cid_to_rel_qids.items():
                cand_triples = []
                for qid in qids[:5]:
                    if qid not in bm25_hid_to_tids:
                        miss_triples += 1
                        miss_qids.add(qid)
                        continue
                    neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    while neg_pid == aid:
                        neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    cand_triples.append((qid, aid, neg_pid))
                for qid, aid, neg_pid in cand_triples[:5]:
                    fout.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, aid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, aid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")

print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.c2cp.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.c2cp.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.c2cp.tsv"), "w") as fout3:
            for aid, compl_pids in max5_cid_to_compl_pids.items():
                assert aid in bm25_hid_to_tids
                    
                neg_pids = []
                sim_pids = list(set(cid_to_sim_pids[aid]) - set(compl_pids))
                neg_sim_pids = random.sample(sim_pids, k=len(sim_pids))[:5]
                assert len(set(neg_sim_pids) & set(compl_pids)) == 0 
                neg_pids += neg_sim_pids
                while len(neg_pids) <= 5:          
                    neg_pid = random.sample(bm25_hid_to_tids[aid], k=1)[0]
                    while neg_pid in compl_pids:
                        neg_pid = random.sample(bm25_hid_to_tids[aid], k=1)[0]
                    neg_pids.append(neg_pid)          
                neg_pids = neg_pids[: len(compl_pids)]
                
                cand_triples = []
                for pos_pid, neg_pid in zip(compl_pids, neg_pids):
                    cand_triples.append((aid, pos_pid, neg_pid))
                    
                for aid, pos_cid, neg_cid in cand_triples[:5]:
                    fout.write(f"{aid}\t{pos_cid}\t{neg_cid}\t{COMPL_RELATION}\n")
                for aid, pos_cid, neg_cid in cand_triples[:3]:
                    fout2.write(f"{aid}\t{pos_cid}\t{neg_cid}\t{COMPL_RELATION}\n")
                for aid, pos_cid, neg_cid in cand_triples[:1]:
                    fout3.write(f"{aid}\t{pos_cid}\t{neg_cid}\t{COMPL_RELATION}\n")
                    
print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
print("number of miss qids = {:,}".format(len(miss_qids)))
print("duplicate pairs in similar_relation = {:,}".format(len(duplicate_pairs)))

number of miss triples = 0
---------------------------------------------------------------------------
number of miss triples = 0
---------------------------------------------------------------------------
number of miss triples = 9,636
---------------------------------------------------------------------------
number of miss triples = 9,636
---------------------------------------------------------------------------
number of miss triples = 16,940
---------------------------------------------------------------------------
number of miss triples = 16,940
---------------------------------------------------------------------------
number of miss qids = 12,389
duplicate pairs in similar_relation = 15,434


In [10]:
# sanity check
for path in os.listdir(out_dir):
    path = os.path.join(out_dir, path)
    if "max5" not in path:
        continue
    ! wc -l $path
    ! head -n 3 $path
    ! tail -n 3 $path
    print("="*100)

178162 /home/jupyter/jointly_rec_and_search/datasets/kgc/train/dual_neg/max5_triples.a2cp.tsv
1106381	761749	692552	is_complementary_to
737352	1662290	546220	is_complementary_to
737352	2065943	1209888	is_complementary_to
329141	610862	830423	is_complementary_to
329141	736039	46877	is_complementary_to
1462385	327543	108975	is_complementary_to
209815 /home/jupyter/jointly_rec_and_search/datasets/kgc/train/dual_neg/max5_triples.a2sp.rnd.tsv
240470	1320299	2140008	is_similar_to
240470	691546	692826	is_similar_to
240470	877953	1212026	is_similar_to
1462385	1952507	892522	is_similar_to
1462385	1430265	303243	is_similar_to
1462385	417274	455866	is_similar_to
270020 /home/jupyter/jointly_rec_and_search/datasets/kgc/train/dual_neg/max5_triples.q2a.tsv
2465451	1106381	356338	is_relevant_to
2599283	1106381	1572209	is_relevant_to
2828662	1106381	478421	is_relevant_to
2784291	1462385	495956	is_relevant_to
2834538	1462385	1263811	is_relevant_to
2568463	1462385	1418355	is_relevant_to
209835 /home/jup

In [27]:
hid, pos_tid, neg_tid = (2828662,1106381,478421)

! grep -P "^{hid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"
! grep -P "^{pos_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"
! grep -P "^{neg_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"

2828662	front doors
1106381	SIMPSON 72-in x 80-in Wood Full Lite Right-Hand Inswing Brown Unfinished Double Front Door Solid Core ; Front Doors
478421	MMI DOOR 72-in x 80-in Fiberglass Half Lite Right-Hand Inswing Black Painted Prehung Double Front Door with Brickmould ; Front Doors
