In [9]:
import os 
import random
import pickle
from collections import defaultdict

import ujson 
import numpy as np
import copy
import pandas as pd
from tqdm import tqdm

np.random.seed(4680)
random.seed(4680)

SIM_RELATION = "is_similar_to"
COMPL_RELATION = "is_complementary_to"
REL_RELATION = "is_relevant_to"

in_dir = "/home/jupyter/jointly_rec_and_search/datasets/kgc/"

aid_to_sim_pids = {}
aid_to_rel_qids = {}
cid_to_sim_pids = {}
cid_to_rel_qids = {}
cid_to_compl_pids = {}
with open(os.path.join(in_dir, "one_hop_relation.pkl"), "rb") as fin:
    obj_dict = pickle.load(fin)
    aid_to_sim_pids = obj_dict["aid_sim"]
    aid_to_rel_qids = obj_dict["aid_rel"]
    cid_to_sim_pids = obj_dict["compl_sim"]
    cid_to_rel_qids = obj_dict["compl_rel"]
    cid_to_compl_pids = obj_dict["compl_compl"]
    
aid_to_compl_pids = defaultdict(list)
with open(os.path.join(in_dir, "arels.compl.train.tsv")) as fin:
    for line in fin:
        aid, _, pid, _ = line.rstrip().split("\t")
        aid, pid = int(aid), int(pid)
        aid_to_compl_pids[aid].append(pid)
    
print("number of aid_to_sim_pid = {:,}, aid_to_rel_qid = {:,}, cid_to_sim_pid = {:,}, cid_to_rel_qid = {:,}, cid_to_compl_pid = {:,}".format(
    sum([len(xs) for xs in aid_to_sim_pids.values()]), sum([len(xs) for xs in aid_to_rel_qids.values()]), 
    sum([len(xs) for xs in cid_to_sim_pids.values()]),
    sum([len(xs) for xs in cid_to_rel_qids.values()]), sum([len(xs) for xs in cid_to_compl_pids.values()])
))
print("number of aid_to_compl_pid = {:,}".format(sum([len(xs) for xs in aid_to_compl_pids.values()])))

number of aid_to_sim_pid = 519,522, aid_to_rel_qid = 2,676,047, cid_to_sim_pid = 377,449, cid_to_rel_qid = 2,146,221, cid_to_compl_pid = 195,865
number of aid_to_compl_pid = 264,134


In [18]:
# all bm25 neg
from collections import defaultdict
out_dir = "/home/jupyter/jointly_rec_and_search/datasets/kgc/train/bm25_neg/"

all_miss = 0
max5_aid_to_compl_pids = {}
max5_aid_to_sim_pids = {}
max5_aid_to_rel_qids = {}
max5_cid_to_sim_pids = {}
max5_cid_to_rel_qids = {}
max5_cid_to_compl_pids = {}
for aid in aid_to_compl_pids:
    max5_aid_to_compl_pids[aid] = random.sample(aid_to_compl_pids[aid], len(aid_to_compl_pids[aid]))[:5]
    
    # aid -> xxx 
    if aid in aid_to_sim_pids:
        max5_aid_to_sim_pids[aid] = random.sample(aid_to_sim_pids[aid], len(aid_to_sim_pids[aid]))[:5]
    if aid in aid_to_rel_qids:
        max5_aid_to_rel_qids[aid] = random.sample(aid_to_rel_qids[aid], len(aid_to_rel_qids[aid]))[:5]
    
    # cid -> xxx 
    compl_pids = max5_aid_to_compl_pids[aid]
    for cid in compl_pids:
        if cid in cid_to_sim_pids:
            max5_cid_to_sim_pids[cid] = random.sample(cid_to_sim_pids[cid], len(cid_to_sim_pids[cid]))[:5]
        if cid in cid_to_rel_qids:
            max5_cid_to_rel_qids[cid] = random.sample(cid_to_rel_qids[cid], len(cid_to_rel_qids[cid]))[:5]
        if cid in cid_to_compl_pids:
            max5_cid_to_compl_pids[cid] = random.sample(cid_to_compl_pids[cid], len(cid_to_compl_pids[cid]))[:5]
            
        if all([cid not in cid_to_sim_pids, cid not in cid_to_rel_qids, cid not in cid_to_compl_pids]):
            all_miss += 1 
            
print("# line of max5_aid_to_compl_pids = {:,}, max5_aid_to_sim_pids = {:,}, max5_aid_to_rel_qids = {:,}".format(
        len(max5_aid_to_compl_pids), len(max5_aid_to_sim_pids), len(max5_aid_to_rel_qids)))
print("# line of max5_cid_to_compl_pids = {:,}, max5_cid_to_sim_pids = {:,}, max5_cid_to_rel_qids = {:,}".format(
        len(max5_cid_to_compl_pids), len(max5_cid_to_sim_pids), len(max5_cid_to_rel_qids)))
print("-"*75)
print("# of max5_aid_to_compl_pids = {:,}, max5_aid_to_sim_pids = {:,}, max5_aid_to_rel_qids = {:,}".format(
        sum([len(x) for x in max5_aid_to_compl_pids.values()]), sum([len(x) for x in max5_aid_to_sim_pids.values()]),
        sum([len(x) for x in max5_aid_to_rel_qids.values()])))
print("# of of max5_cid_to_compl_pids = {:,}, max5_cid_to_sim_pids = {:,}, max5_cid_to_rel_qids = {:,}".format(
        sum([len(x) for x in max5_cid_to_compl_pids.values()]), sum([len(x) for x in max5_cid_to_sim_pids.values()]),
        sum([len(x) for x in max5_cid_to_rel_qids.values()])))
print("-"*75)
print("all miss = {:,}, miss_rate = {:.3f}".format(all_miss, all_miss/sum([len(x) for x in max5_aid_to_compl_pids.values()])))

# bm25 run
run_path = "/home/jupyter/jointly_rec_and_search/datasets/kgc/runs/bm25.train.run"
df = pd.read_csv(run_path, sep=" ", names=["hid", "q0", "tid", "rank", "score", "model_name"])
bm25_hid_to_tids = {}
ignore_hids = set()
for hid, group in df.groupby("hid"):
    cand_tids = list(group.tid.values)
    if len(cand_tids) < 10:
        ignore_hids.add(int(hid))
    else:
        bm25_hid_to_tids[int(hid)] = [int(x) for x in cand_tids]
        
print("number of ignore hids = {}".format(len(ignore_hids)))

# line of max5_aid_to_compl_pids = 69,496, max5_aid_to_sim_pids = 55,192, max5_aid_to_rel_qids = 62,643
# line of max5_cid_to_compl_pids = 29,270, max5_cid_to_sim_pids = 42,291, max5_cid_to_rel_qids = 48,785
---------------------------------------------------------------------------
# of max5_aid_to_compl_pids = 178,162, max5_aid_to_sim_pids = 218,676, max5_aid_to_rel_qids = 279,656
# of of max5_cid_to_compl_pids = 101,051, max5_cid_to_sim_pids = 166,720, max5_cid_to_rel_qids = 216,724
---------------------------------------------------------------------------
all miss = 5,290, miss_rate = 0.030
number of ignore hids = 4842


FileNotFoundError: [Errno 2] No such file or directory: '/home/jupyter/jointly_rec_and_search/datasets/kgc/train/bm25_neg/max5_triples.a2cp.tsv'

In [119]:
eid_to_text = {}
with open("/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv") as fin:
    for line in fin:
        eid, text = line.strip().split("\t")
        eid_to_text[int(eid)] = text

if not os.path.exists(out_dir):
    os.mkdir(out_dir)
# start to create triples
miss_triples = 0
miss_qids = set()
duplicate_pairs = []
with open(os.path.join(out_dir, "max5_triples.a2cp.tsv"), "w") as fout:
    for aid, compl_pids in max5_aid_to_compl_pids.items():
        if aid not in bm25_hid_to_tids:
            miss_triples += 1
            continue
        for pos_cid in compl_pids:
            neg_cid = random.sample(bm25_hid_to_tids[aid], k=1)[0]
            while neg_cid in compl_pids:
                neg_cid = random.sample(bm25_hid_to_tids[aid], k=1)[0]
            fout.write(f"{aid}\t{pos_cid}\t{neg_cid}\t{COMPL_RELATION}\n")
            
print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.a2sp.bm25.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.a2sp.bm25.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.a2sp.bm25.tsv"), "w") as fout3:
            for aid, sim_pids in max5_aid_to_sim_pids.items():
                if aid not in bm25_hid_to_tids:
                    miss_triples += 1
                    continue
                cand_triples = []
                for pos_pid in sim_pids:
                    if eid_to_text[aid] == eid_to_text[pos_pid]:
                        duplicate_pairs.append((aid, pos_pid))
                        continue
                    neg_pid = random.sample(bm25_hid_to_tids[aid][-50:], k=1)[0]
                    while neg_pid in sim_pids:
                        neg_pid = random.sample(bm25_hid_to_tids[aid][-50:], k=1)[0]
                    cand_triples.append((aid, pos_pid, neg_pid))
                for aid, pos_pid, neg_pid in cand_triples[:5]:
                    fout.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                    
print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.a2sp.rnd.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.a2sp.rnd.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.a2sp.rnd.tsv"), "w") as fout3:
            for aid, sim_pids in max5_aid_to_sim_pids.items():
                if aid not in bm25_hid_to_tids:
                    miss_triples += 1
                    continue
                cand_triples = []
                for pos_pid in sim_pids:
                    if eid_to_text[aid] == eid_to_text[pos_pid]:
                        duplicate_pairs.append((aid, pos_pid))
                        continue
                    neg_pid = random.sample(range(2_000_000), k=1)[0]
                    cand_triples.append((aid, pos_pid, neg_pid))
                for aid, pos_pid, neg_pid in cand_triples[:5]:
                    fout.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")

print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.q2a.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.q2a.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.q2a.tsv"), "w") as fout3:
            for aid, qids in max5_aid_to_rel_qids.items():
                cand_triples = []
                for qid in qids[:5]:
                    if qid not in bm25_hid_to_tids:
                        miss_triples += 1
                        miss_qids.add(qid)
                        continue
                    neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    while neg_pid == aid:
                        neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    cand_triples.append((qid, aid, neg_pid))
                for qid, aid, neg_pid in cand_triples[:5]:
                    fout.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, aid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, aid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")

print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.c2sp.bm25.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.c2sp.bm25.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.c2sp.bm25.tsv"), "w") as fout3:
            for aid, sim_pids in max5_cid_to_sim_pids.items():
                if aid not in bm25_hid_to_tids:
                    miss_triples += 1
                    continue
                cand_triples = []
                for pos_pid in sim_pids:
                    if eid_to_text[aid] == eid_to_text[pos_pid]:
                        duplicate_pairs.append((aid, pos_pid))
                        continue
                        
                    neg_pid = random.sample(bm25_hid_to_tids[aid][-50:], k=1)[0]
                    while neg_pid in sim_pids:
                        neg_pid = random.sample(bm25_hid_to_tids[aid][-50:], k=1)[0]
                    cand_triples.append((aid, pos_pid, neg_pid))
                for aid, pos_pid, neg_pid in cand_triples[:5]:
                    fout.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n") 
                    
print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.c2sp.rnd.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.c2sp.rnd.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.c2sp.rnd.tsv"), "w") as fout3:
            for aid, sim_pids in max5_cid_to_sim_pids.items():
                if aid not in bm25_hid_to_tids:
                    miss_triples += 1
                    continue
                cand_triples = []
                for pos_pid in sim_pids:
                    if eid_to_text[aid] == eid_to_text[pos_pid]:
                        duplicate_pairs.append((aid, pos_pid))
                        continue
                        
                    neg_pid = random.sample(range(2_000_000), k=1)[0]
                    cand_triples.append((aid, pos_pid, neg_pid))
                for aid, pos_pid, neg_pid in cand_triples[:5]:
                    fout.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for aid, pos_pid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{aid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n") 

print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.q2c.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.q2c.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.q2c.tsv"), "w") as fout3:
            for aid, qids in max5_cid_to_rel_qids.items():
                cand_triples = []
                for qid in qids[:5]:
                    if qid not in bm25_hid_to_tids:
                        miss_triples += 1
                        miss_qids.add(qid)
                        continue
                    neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    while neg_pid == aid:
                        neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    cand_triples.append((qid, aid, neg_pid))
                for qid, aid, neg_pid in cand_triples[:5]:
                    fout.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, aid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, aid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{qid}\t{aid}\t{neg_pid}\t{REL_RELATION}\n")

print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
with open(os.path.join(out_dir, "max5_triples.c2cp.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.c2cp.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.c2cp.tsv"), "w") as fout3:
            for aid, compl_pids in max5_cid_to_compl_pids.items():
                if aid not in bm25_hid_to_tids:
                    miss_triples += 1
                    continue
                cand_triples = []
                for pos_cid in compl_pids:
                    neg_cid = random.sample(bm25_hid_to_tids[aid], k=1)[0]
                    while neg_cid in compl_pids:
                        neg_cid = random.sample(bm25_hid_to_tids[aid], k=1)[0]
                    cand_triples.append((aid, pos_cid, neg_cid))
                for aid, pos_cid, neg_cid in cand_triples[:5]:
                    fout.write(f"{aid}\t{pos_cid}\t{neg_cid}\t{COMPL_RELATION}\n")
                for aid, pos_cid, neg_cid in cand_triples[:3]:
                    fout2.write(f"{aid}\t{pos_cid}\t{neg_cid}\t{COMPL_RELATION}\n")
                for aid, pos_cid, neg_cid in cand_triples[:1]:
                    fout3.write(f"{aid}\t{pos_cid}\t{neg_cid}\t{COMPL_RELATION}\n")
                    
print("number of miss triples = {:,}".format(miss_triples))         
print("-"*75)
print("number of miss qids = {:,}".format(len(miss_qids)))
print("duplicate pairs in similar_relation = {:,}".format(len(duplicate_pairs)))

number of miss triples = 0
---------------------------------------------------------------------------
number of miss triples = 0
---------------------------------------------------------------------------
number of miss triples = 0
---------------------------------------------------------------------------
number of miss triples = 9,476
---------------------------------------------------------------------------
number of miss triples = 9,476
---------------------------------------------------------------------------
number of miss triples = 9,476
---------------------------------------------------------------------------
number of miss triples = 16,899
---------------------------------------------------------------------------
number of miss triples = 16,899
---------------------------------------------------------------------------
number of miss qids = 12,417
duplicate pairs in similar_relation = 30,582


In [120]:
# sanity check
for path in os.listdir(out_dir):
    path = os.path.join(out_dir, path)
    if "max5" not in path:
        continue
    ! wc -l $path
    ! head -n 3 $path
    ! tail -n 3 $path
    print("="*100)

209875 /home/jupyter/jointly_rec_and_search/datasets/kgc/train/bm25_neg/max5_triples.a2sp.bm25.tsv
240470	1320299	933444	is_similar_to
240470	2058234	1964489	is_similar_to
240470	877953	238288	is_similar_to
1462385	417274	2239696	is_similar_to
1462385	2151833	2239696	is_similar_to
1462385	1561468	718072	is_similar_to
160230 /home/jupyter/jointly_rec_and_search/datasets/kgc/train/bm25_neg/max5_triples.c2sp.bm25.tsv
2065943	4930	1620254	is_similar_to
2065943	1350909	1807818	is_similar_to
2065943	49613	2203542	is_similar_to
610862	1306820	1129289	is_similar_to
610862	1630954	1056892	is_similar_to
610862	1598527	537589	is_similar_to
178162 /home/jupyter/jointly_rec_and_search/datasets/kgc/train/bm25_neg/max5_triples.a2cp.tsv
1106381	761749	1643479	is_complementary_to
737352	1662290	2122155	is_complementary_to
737352	2065943	1122750	is_complementary_to
329141	736039	1176457	is_complementary_to
329141	610862	1284066	is_complementary_to
1462385	327543	2220960	is_complementary_to
209875 /home/

In [118]:
hid, pos_tid, neg_tid = (610862,1306820,356893)

! grep -P "^{hid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"
! grep -P "^{pos_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"
! grep -P "^{neg_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"

print("-"*75)

hid, pos_tid, neg_tid = (240470,1320299,97079)

! grep -P "^{hid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"
! grep -P "^{pos_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"
! grep -P "^{neg_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"

610862	Jacuzzi 1.5-in Brushed Nickel Foot Lock Drain with Plastic Pipe ; Bathtub Drains
1306820	Keeney 1.5-in Brushed Nickel Foot Lock Drain with Polypropylene Pipe ; Bathtub Drains
356893	Ekena Millwork Sandblasted Endurathane Faux Wood 6-in x 10-in x 96-in Aged Ash Prefinished Polyurethane Decorative Beam ; Faux Beams
---------------------------------------------------------------------------
240470	Kraloy Gray PVC Weatherproof New Work/Old Work Standard Enclosure Exterior Electrical Box ; Electrical Boxes
1320299	Sigma Engineered Solutions 1-Gang Weatherproof Box 1-Gang Gray Metal Weatherproof New Work Standard Rectangular Exterior Electrical Box ; Electrical Boxes
97079	Southwire 1-Gang Gray Polycarbonate New Work/Old Work Standard Adjustable Ceiling Electrical Box ; Electrical Boxes


In [89]:
eid_to_text = {}
with open("/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv") as fin:
    for line in fin:
        eid, text = line.strip().split("\t")
        eid_to_text[eid] = text

total_num = 0
equal_pairs = []
with open("/home/jupyter/jointly_rec_and_search/datasets/kgc/train/bm25_neg/max5_triples.a2sp.tsv") as fin:
    for line in fin:
        hid, tid, _, _ = line.strip().split("\t")
        if eid_to_text[hid] == eid_to_text[tid]:
            equal_pairs.append((hid, tid))
        total_num += 1
            
with open("/home/jupyter/jointly_rec_and_search/datasets/kgc/train/bm25_neg/max5_triples.c2sp.tsv") as fin:
    for line in fin:
        hid, tid, _, _ = line.strip().split("\t")
        if eid_to_text[hid] == eid_to_text[tid]:
            equal_pairs.append((hid, tid))
        total_num += 1

print("number of equal_pairs = {:,}, rate = {:.3f}".format(len(equal_pairs), len(equal_pairs)/total_num))
# hid, pos_tid, neg_tid = (610862,1598527,589685)

#! grep -P "^{hid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"
#! grep -P "^{pos_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"
#! grep -P "^{neg_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/all_entites.tsv"

number of equal_pairs = 15,291, rate = 0.040


In [None]:
import os 
import random

import ujson 
import numpy as np
import copy
import pandas as pd
from tqdm import tqdm

run_path = "/home/jupyter/jointly_rec_and_search/datasets/kgc/runs/bm25.train.run"

if not os.path.exists(out_dir):
    os.mkdir(out_dir)
    
df = pd.read_csv(run_path, sep=" ", names=["aid", "q0", "pid", "rank", "score", "model_name"])
bm25_hid_to_pids = {}
ignore_aids = set()
for aid, group in df.groupby("aid"):
    cand_pids = list(group.pid.values)
    if len(cand_pids) < 10:
        ignore_aids.add(int(aid))
    else:
        bm25_hid_to_pids[int(aid)] = [int(x) for x in cand_pids]
        
assert len(ignore_aids) == 0

train_aids = set()
with open(os.path.join(in_dir, "anchors_title.train.tsv")) as fin:
    for line in fin:
        aid, title = line.rstrip().split("\t")
        train_aids.add(int(aid))

one_hop_examples = []
with open(os.path.join(in_dir, "one_hop_examples.jsonl")) as fin:
    for line in fin:
        example = ujson.loads(line.rstrip())
        for key in example:
            assert len(example) == 2
            if key != "compl_pids":
                aid = int(key)
                
        if aid in train_aids:
            one_hop_examples.append(example)

ignore_num = 0
no_enough_simpids = 0
for topk in [5]:
    train_examples = []
    for example in tqdm(one_hop_examples, total=len(one_hop_examples)):
        train_exp = {"compl_pids": [], "sim_pids": []}
        assert len(example) == 2, example
        for key, vals in example.items():
            if key == "compl_pids":
                compl_pids = [int(pid) for pid in vals.keys()]
                train_exp["compl_pids"] = compl_pids
                random.shuffle(compl_pids)
                train_exp["compl_pids"] = compl_pids[:topk]
        for key, vals in example.items():
            if key != "compl_pids":
                train_exp["aid"] = int(key)
                sim_pids = [int(pid) for pid in example[key]["sim_pids"]]
                random.shuffle(sim_pids)
                train_exp["sim_pids"] = sim_pids[:len(train_exp["compl_pids"])]
                
                if len(train_exp["sim_pids"]) < len(train_exp["compl_pids"]):
                    no_enough_simpids += 1
                    cand_pids = bm25_hid_to_pids[train_exp["aid"]]
                    random.shuffle(cand_pids)
                    
                    remain_pids = []      
                    for pid in cand_pids:
                        if pid in compl_pids:
                            continue
                        remain_pids.append(pid)
                        if len(remain_pids) >= len(train_exp["compl_pids"]) - len(train_exp["sim_pids"]):
                            break
                    train_exp["sim_pids"] += remain_pids
                        
                    
        train_examples.append(train_exp)
        
    with open(os.path.join(out_dir, f"top{topk}_triples.tsv"), "w") as fout:
        for train_exp in train_examples:
            aid = train_exp["aid"]
            pos_neg_pairs = zip(train_exp["compl_pids"], train_exp["sim_pids"])
            for pos_pid, neg_pid in pos_neg_pairs:
                fout.write(f"{aid}\t{pos_pid}\t{neg_pid}\n")
                
    print(f"no_enough_simpids for top-{topk}: ", no_enough_simpids)
            
    

In [20]:
# sanity check
for path in os.listdir(out_dir):
    path = os.path.join(out_dir, path)
    ! wc -l $path
    ! head -n 3 $path
    ! tail -n 3 $path
    print("="*100)

69496 /home/jupyter/jointly_rec_and_search/datasets/rec_search/rec_compl/train/train_5compl_5sim.json
{"compl_pids":[2186809],"sim_pids":[1296125],"aid":501931}
{"compl_pids":[1960065,774261],"sim_pids":[500345,2128382],"aid":220006}
{"compl_pids":[855403,1654671,565304,1939891,1185215],"sim_pids":[1952428,70274,393439,414956,351113],"aid":1970525}
{"compl_pids":[273456],"sim_pids":[672057],"aid":1822620}
{"compl_pids":[533708,1043030],"sim_pids":[485389,1381526],"aid":132127}
{"compl_pids":[1870322],"sim_pids":[1889462],"aid":385303}
69496 /home/jupyter/jointly_rec_and_search/datasets/rec_search/rec_compl/train/train_2compl_2sim.json
{"compl_pids":[2186809],"sim_pids":[1890315],"aid":501931}
{"compl_pids":[774261,1960065],"sim_pids":[1702313,563604],"aid":220006}
{"compl_pids":[2257172,586807],"sim_pids":[393439,1952428],"aid":1970525}
{"compl_pids":[273456],"sim_pids":[971697],"aid":1822620}
{"compl_pids":[533708,1043030],"sim_pids":[1078725,485389],"aid":132127}
{"compl_pids":[18703

In [4]:
! grep -P "^132127\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/collection_title_catalog.tsv"
! grep -P "^1043030\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/collection_title_catalog.tsv"
! grep -P "^1381526\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc/collection_title_catalog.tsv"

132127	Jacuzzi 1500-Watt Inline Heater [SEP] Whirlpool Tub & Air Bath Parts
1043030	Jacuzzi 1.5-in Brushed Nickel Foot Lock Drain with Plastic Pipe [SEP] Bathtub Drains
1381526	WaterTECH Whirpool 110 Volt tub heater [SEP] Whirlpool Tub & Air Bath Parts


In [None]:
path = "/home/jupyter/jointly_rec_and_search/preprocess/rec_compl/dataset//"