In [9]:
import os 
import random
import pickle
from collections import defaultdict

import ujson 
import numpy as np
import copy
import pandas as pd
from tqdm import tqdm

np.random.seed(4680)
random.seed(4680)

SIM_RELATION = "is_similar_to"
COMPL_RELATION = "is_complementary_to"
REL_RELATION = "is_relevant_to"

in_dir = "/home/jupyter/jointly_rec_and_search/datasets/kgc_search/"

pid_to_simpids = {}
pid_to_complpids = {}
pid_to_qids = {}

with open(os.path.join(in_dir, "one_hop_relation.pkl"), "rb") as fin:
    obj_dict = pickle.load(fin)
    
    pid_to_simpids = obj_dict["item_sim"]
    pid_to_complpids = obj_dict["item_compl"]
    pid_to_qids = obj_dict["item_rel"]
    
train_qid_to_pids = defaultdict(list)
with open(os.path.join(in_dir, "qrels.train.tsv")) as fin:
    for line in fin:
        qid, _, pid, _ = line.rstrip().split("\t")
        qid, pid = int(qid), int(pid)
        train_qid_to_pids[qid].append(pid)

print("number of (pid,simpid) = {:,}, (pid, complpid) = {:,}, (pid, qid) = {:,}, (train_qid, pid) = {:,}".format(
    sum([len(xs) for xs in pid_to_simpids.values()]), sum([len(xs) for xs in pid_to_complpids.values()]), 
    sum([len(xs) for xs in pid_to_qids.values()]), sum([len(xs) for xs in train_qid_to_pids.values()])))

number of (pid,simpid) = 888,642, (pid, complpid) = 316,765, (pid, qid) = 2,923,831, (train_qid, pid) = 3,259,904


In [10]:
from collections import defaultdict

all_miss = 0
all_pairs = 0
max5_pid_to_simpids = {}
max5_pid_to_complpids = {}
max5_pid_to_qids = {}

for train_qid, anchor_pids in tqdm(train_qid_to_pids.items(), total=len(train_qid_to_pids)):
    for a_pid in anchor_pids:
        if a_pid in pid_to_simpids:
            max5_pid_to_simpids[a_pid] = random.sample(pid_to_simpids[a_pid], len(pid_to_simpids[a_pid]))[:5]
        if a_pid in pid_to_complpids:
            max5_pid_to_complpids[a_pid] = random.sample(pid_to_complpids[a_pid], len(pid_to_complpids[a_pid]))[:5]
        if a_pid in pid_to_qids:
            max5_pid_to_qids[a_pid] = random.sample(pid_to_qids[a_pid], len(pid_to_qids[a_pid]))[:5]
        if all([a_pid not in pid_to_simpids, a_pid not in pid_to_complpids, a_pid not in pid_to_qids]):
            all_miss += 1
        all_pairs += 1
print("all_pairs = {:,}, all_miss = {:,} ({:.3f})".format(all_pairs, all_miss, all_miss/all_pairs))

# bm25 run

run_path = "/home/jupyter/jointly_rec_and_search/datasets/kgc_search/runs/bm25.train.run"
df = pd.read_csv(run_path, sep=" ", names=["hid", "q0", "tid", "rank", "score", "model_name"])
bm25_hid_to_tids = {}
ignore_hids = set()
for hid, group in df.groupby("hid"):
    cand_tids = list(group.tid.values)
    if len(cand_tids) < 20:
        ignore_hids.add(int(hid))
    else:
        bm25_hid_to_tids[int(hid)] = [int(x) for x in cand_tids]
        
print("number of ignore hids = {}".format(len(ignore_hids)))


100%|██████████| 763019/763019 [03:04<00:00, 4140.24it/s]


all_pairs = 3,259,904, all_miss = 83,071 (0.025)
number of ignore hids = 7158


In [13]:
eid_to_text = {}
with open("/home/jupyter/jointly_rec_and_search/datasets/kgc_search/all_entites.tsv") as fin:
    for line in fin:
        eid, text = line.strip().split("\t")
        eid_to_text[int(eid)] = text

out_dir = "/home/jupyter/jointly_rec_and_search/datasets/kgc_search/train/bm25_neg/"
if not os.path.exists(out_dir):
    os.mkdir(out_dir)
    
# start to create triples
miss_triples = 0
miss_qids = set()
duplicate_pairs = []

with open(os.path.join(out_dir, "max5_triples.train_q2p.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max2_triples.train_q2p.tsv"), "w") as fout2:
        for qid, pos_pids in tqdm(train_qid_to_pids.items(), total=len(train_qid_to_pids)):
            cand_triples = []
            if qid not in bm25_hid_to_tids:
                miss_qids.add(qid)
                continue
            for pos_pid in pos_pids:
                neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                while neg_pid == pos_pid:
                    neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                cand_triples.append((qid, pos_pid, neg_pid))
            
            for qid, pos_pid, neg_pid in cand_triples[:5]:
                fout.write(f"{qid}\t{pos_pid}\t{neg_pid}\t{REL_RELATION}\n")
            for qid, pos_pid, neg_pid in cand_triples[:2]:
                fout2.write(f"{qid}\t{pos_pid}\t{neg_pid}\t{REL_RELATION}\n")
                
print("miss_qids = {:,}".format(len(miss_qids)))

with open(os.path.join(out_dir, "max5_triples.q2p.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.q2p.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.q2p.tsv"), "w") as fout3:
            for pos_pid, qids in max5_pid_to_qids.items():
                cand_triples = []
                for qid in qids[:5]:
                    if qid not in bm25_hid_to_tids:
                        miss_qids.add(qid)
                        continue
                    neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    while neg_pid == pos_pid:
                        neg_pid = random.sample(bm25_hid_to_tids[qid], k=1)[0]
                    cand_triples.append((qid, pos_pid, neg_pid))
                    
                for qid, pos_pid, neg_pid in cand_triples[:5]:
                    fout.write(f"{qid}\t{pos_pid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, pos_pid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{qid}\t{pos_pid}\t{neg_pid}\t{REL_RELATION}\n")
                for qid, pos_pid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{qid}\t{pos_pid}\t{neg_pid}\t{REL_RELATION}\n")

print("miss_qids = {:,}".format(len(miss_qids)))

with open(os.path.join(out_dir, "max5_triples.p2sp.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.p2sp.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.p2sp.tsv"), "w") as fout3:
            for pid, sim_pids in max5_pid_to_simpids.items():
                assert pid in bm25_hid_to_tids, pid
                cand_triples = []
                for pos_pid in sim_pids:
                    if eid_to_text[pid] == eid_to_text[pos_pid]:
                        duplicate_pairs.append((pid, pos_pid))
                        continue
                    neg_pid = random.sample(range(2_000_000), k=1)[0]
                    cand_triples.append((pid, pos_pid, neg_pid))
                    
                for pid, pos_pid, neg_pid in cand_triples[:5]:
                    fout.write(f"{pid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for pid, pos_pid, neg_pid in cand_triples[:3]:
                    fout2.write(f"{pid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                for pid, pos_pid, neg_pid in cand_triples[:1]:
                    fout3.write(f"{pid}\t{pos_pid}\t{neg_pid}\t{SIM_RELATION}\n")
                    
with open(os.path.join(out_dir, "max5_triples.p2cp.tsv"), "w") as fout:
    with open(os.path.join(out_dir, "max3_triples.p2cp.tsv"), "w") as fout2:
        with open(os.path.join(out_dir, "max1_triples.p2cp.tsv"), "w") as fout3:
            for pid, compl_pids in max5_pid_to_complpids.items():
                assert pid in bm25_hid_to_tids, pid 
                cand_triples = []
                for pos_pid in compl_pids:
                    neg_pid = random.sample(bm25_hid_to_tids[pid], k=1)[0]
                    while neg_pid in compl_pids:
                        neg_pid = random.sample(bm25_hid_to_tids[pid], k=1)[0]
                    cand_triples.append((pid, pos_pid, neg_pid))
                
                for pid, pos_pid, neg_pid in cand_triples[:5]:
                    fout.write(f"{pid}\t{pos_pid}\t{neg_pid}\t{COMPL_RELATION}\n")
                for pid, pos_pid, neg_pid in cand_triples[:5]:
                    fout2.write(f"{pid}\t{pos_pid}\t{neg_pid}\t{COMPL_RELATION}\n")
                for pid, pos_pid, neg_pid in cand_triples[:5]:
                    fout3.write(f"{pid}\t{pos_pid}\t{neg_pid}\t{COMPL_RELATION}\n")

print("-"*75)
print("number of miss qids = {:,}".format(len(miss_qids)))
print("duplicate pairs in similar_relation = {:,}".format(len(duplicate_pairs)))

100%|██████████| 763019/763019 [00:12<00:00, 63428.17it/s]


miss_qids = 47,903
miss_qids = 47,903
---------------------------------------------------------------------------
number of miss qids = 47,903
duplicate pairs in similar_relation = 27,280


In [14]:
# sanity check
for path in os.listdir(out_dir):
    path = os.path.join(out_dir, path)
    if "max5" not in path:
        continue
    ! wc -l $path
    ! head -n 3 $path
    ! tail -n 3 $path
    print("="*100)

531202 /home/jupyter/jointly_rec_and_search/datasets/kgc_search/train/bm25_neg/max5_triples.p2sp.tsv
1909034	1624640	1606980	is_similar_to
1909034	114539	1899549	is_similar_to
1909034	1289078	1681662	is_similar_to
1659246	1061716	1348584	is_similar_to
981419	1413473	1061334	is_similar_to
981419	1035097	614898	is_similar_to
1744194 /home/jupyter/jointly_rec_and_search/datasets/kgc_search/train/bm25_neg/max5_triples.train_q2p.tsv
2798039	1909034	2013959	is_relevant_to
2308312	2039319	1162673	is_relevant_to
2743211	829128	1109469	is_relevant_to
3086481	1470444	1252548	is_relevant_to
2805636	1806689	674143	is_relevant_to
2742849	1551045	1809179	is_relevant_to
752591 /home/jupyter/jointly_rec_and_search/datasets/kgc_search/train/bm25_neg/max5_triples.q2p.tsv
2858855	1909034	1591771	is_relevant_to
2611806	1909034	280886	is_relevant_to
2265963	1909034	1318337	is_relevant_to
3085531	2025634	1603117	is_relevant_to
2733127	1243971	2256199	is_relevant_to
2733127	307097	1342821	is_relevant_to
2094

In [32]:
! grep -P "^15123\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc_search/all_entites.tsv"
! grep -P "^1485892\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc_search/all_entites.tsv"
! grep -P "^1948932\t" "/home/jupyter/jointly_rec_and_search/datasets/kgc_search/all_entites.tsv"

15123	Bonnie Plants Zucchini Plant in 11.8-oz Pot ; Vegetable Plants
1485892	Bonnie Plants 2-Pack Sweet Basil in 19.3-oz Pot ; Herb Plants
1948932	Clean-X Heavy Duty Cotton Mop Cotton Non-wringing String Wet Mop ; Wet Mops


In [15]:
! wc -l /home/jupyter/jointly_rec_and_search/datasets/kgc_search/train/bm25_neg/max2_triples.train_q2p.tsv

1099241 /home/jupyter/jointly_rec_and_search/datasets/kgc_search/train/bm25_neg/max2_triples.train_q2p.tsv
