In [16]:
import os
import pickle as pkl
from collections import defaultdict
import random 

from tqdm import tqdm
import networkx as nx
import pandas as pd
random.seed(4680)

in_dir = "/home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/"
datas = []
         
fns = [
    "train_aid_to_simpids.pkl",
    "train_aid_to_complpids.pkl",
    "train_qid_to_relpids.pkl",
    ]
for fn in fns:
    with open(os.path.join(in_dir, fn), "rb") as fin:
        datas.append(pkl.load(fin))

train_aid_to_simpids, train_aid_to_complpids, train_qid_to_pids = datas

G = nx.MultiDiGraph()
SIM_RELATION = "is_similar_to"
COMPL_RELATION = "is_complementary_to"
REL_RELATION = "is_relevant_to"

for aid, sim_pids in train_aid_to_simpids.items():
    triples = [(aid, sim_pid, {"type":SIM_RELATION}) for sim_pid in sim_pids]
    G.add_edges_from(triples)
    
for aid, compl_pids in train_aid_to_complpids.items():
    triples = [(aid, compl_pid, {"type":COMPL_RELATION}) for compl_pid in compl_pids]
    G.add_edges_from(triples)
    
for qid, pids in train_qid_to_pids.items():
    triples = [(pid, qid, {"type": REL_RELATION}) for pid in pids]
    G.add_edges_from(triples)
    
multi_edge_pairs = []
for n, nbrs_dict in tqdm(G.adj.items(), total=G.number_of_nodes()):
    for nbr_node, edge_attrs in nbrs_dict.items():
        assert len(edge_attrs) == 1 or len(edge_attrs) == 2
        if len(edge_attrs) == 2:
            multi_edge_pairs.append((n, nbr_node))
            
print("number of edges = {:,}, number of multi-attr edges = {:,}, ({:.3f})".format(G.number_of_edges(), len(multi_edge_pairs), 
                                                                                   len(multi_edge_pairs)/G.number_of_edges()))

PIDS = []
with open(os.path.join(in_dir, "collection_title.tsv")) as fin:
    for line in fin:
        PIDS.append(int(line.strip().split("\t")[0]))
print(f"max pids = {max(PIDS)}")
    

100%|██████████| 843496/843496 [00:04<00:00, 174902.41it/s]


number of edges = 2,746,193, number of multi-attr edges = 1,090, (0.000)
max pids = 1216069


In [18]:
eid_to_text = {}
with open(os.path.join(in_dir, "all_entities.tsv")) as fin:
    for line in fin:
        eid, text = line.strip().split("\t")
        eid_to_text[int(eid)] = text
        
run_path = os.path.join(in_dir, "runs/bm25.all.run")
df = pd.read_csv(run_path, sep=" ", names=["hid", "q0", "tid", "rank", "score", "model_name"])
bm25_hid_to_tids = {}
ignore_hids = set()
for hid, group in df.groupby("hid"):
    cand_tids = list(group.tid.values)
    if len(cand_tids) < 10:
        ignore_hids.add(int(hid))
    else:
        bm25_hid_to_tids[int(hid)] = [int(x) for x in cand_tids]
        
print("number of ignore hids = {}".format(len(ignore_hids)))

number of ignore hids = 2354


In [22]:
MAX_PID = max(PIDS)
def create_mixture_triples_for_search(hid, pos_tid, hid_to_simpids_sampler, hid_to_complpids_sampler, bm25_hid_to_tid_sampler,
                                     miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs):
    if pos_tid in hid_to_simpids_sampler:
        neg_tid = random.sample(hid_to_simpids_sampler[pos_tid], k=1)[0]
        assert neg_tid != pos_tid
        hids_have_simnegs.append(hid)
    elif pos_tid in hid_to_complpids_sampler:
        neg_tid = random.sample(hid_to_complpids_sampler[pos_tid], k=1)[0]
        assert neg_tid != pos_tid
        hids_have_complnegs.append(hid)
    elif hid in bm25_hid_to_tid_sampler:
        neg_tid = random.sample(bm25_hid_to_tid_sampler[hid], k=1)[0]
        while neg_tid == pos_tid:
            neg_tid = random.sample(bm25_hid_to_tid_sampler[hid], k=1)[0]
        hids_have_bm25negs.append(hid)
    else:
        miss_hids.append(hid)
        return 0 
    
    return (hid, pos_tid, neg_tid)

max15_q2p = {}
max10_q2p = {}
max5_q2p = {}
max2_q2p = {}

max15_q2p_triples = []
max10_q2p_triples = []
max5_q2p_triples = []
max2_q2p_triples = []
for qid, pids in train_qid_to_pids.items():
    max15_q2p[qid] = random.sample(pids, k=len(pids))[:15]
    max10_q2p[qid] = random.sample(pids, k=len(pids))[:10]
    max5_q2p[qid] = random.sample(pids, k=len(pids))[:5]
    max2_q2p[qid] = random.sample(pids, k=len(pids))[:2]
    
miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs = [], [], [], []
total_triples = 0   
for qid, pos_pids in max15_q2p.items():
    for pos_pid in pos_pids:
        triple = create_mixture_triples_for_search(qid, pos_pid, 
                                                   hid_to_simpids_sampler=train_aid_to_simpids,
                                                   hid_to_complpids_sampler=train_aid_to_complpids, 
                                                   bm25_hid_to_tid_sampler=bm25_hid_to_tids,
                                                   miss_hids=miss_hids, 
                                                   hids_have_simnegs=hids_have_simnegs, 
                                                   hids_have_complnegs=hids_have_complnegs, 
                                                   hids_have_bm25negs=hids_have_bm25negs)
        if triple != 0:
            max15_q2p_triples.append(triple)
            total_triples += 1
print("miss_hids = {:,}, triples with simnegs = {:,}, triples with complnegs = {:,}, triples with bm25negs = {:,}".format(
    len(miss_hids), len(hids_have_simnegs), len(hids_have_complnegs), len(hids_have_bm25negs)
))
print("total triples = {:,}".format(total_triples))
print("="*75)

miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs = [], [], [], []
total_triples = 0   
for qid, pos_pids in max10_q2p.items():
    for pos_pid in pos_pids:
        triple = create_mixture_triples_for_search(qid, pos_pid, 
                                                   hid_to_simpids_sampler=train_aid_to_simpids,
                                                   hid_to_complpids_sampler=train_aid_to_complpids, 
                                                   bm25_hid_to_tid_sampler=bm25_hid_to_tids,
                                                   miss_hids=miss_hids, 
                                                   hids_have_simnegs=hids_have_simnegs, 
                                                   hids_have_complnegs=hids_have_complnegs, 
                                                   hids_have_bm25negs=hids_have_bm25negs)
        if triple != 0:
            max10_q2p_triples.append(triple)
            total_triples += 1
print("miss_hids = {:,}, triples with simnegs = {:,}, triples with complnegs = {:,}, triples with bm25negs = {:,}".format(
    len(miss_hids), len(hids_have_simnegs), len(hids_have_complnegs), len(hids_have_bm25negs)
))
print("total triples = {:,}".format(total_triples))
print("="*75)

miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs = set(), set(), set(), set()
total_triples = 0   
for qid, pos_pids in max5_q2p.items():
    for pos_pid in pos_pids:
        triple = create_mixture_triples_for_search(qid, pos_pid, 
                                                   hid_to_simpids_sampler=train_aid_to_simpids,
                                                   hid_to_complpids_sampler=train_aid_to_complpids, 
                                                   bm25_hid_to_tid_sampler=bm25_hid_to_tids,
                                                   miss_hids=miss_hids, 
                                                   hids_have_simnegs=hids_have_simnegs, 
                                                   hids_have_complnegs=hids_have_complnegs, 
                                                   hids_have_bm25negs=hids_have_bm25negs)
        if triple != 0:
            max5_q2p_triples.append(triple)
            total_triples += 1
print("miss_hids = {:,}, triples with simnegs = {:,}, triples with complnegs = {:,}, triples with bm25negs = {:,}".format(
    len(miss_hids), len(hids_have_simnegs), len(hids_have_complnegs), len(hids_have_bm25negs)
))
print("total triples = {:,}".format(total_triples))
print("="*75)

miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs = set(), set(), set(), set()
total_triples = 0   
for qid, pos_pids in max2_q2p.items():
    for pos_pid in pos_pids:
        triple = create_mixture_triples_for_search(qid, pos_pid, 
                                                   hid_to_simpids_sampler=train_aid_to_simpids,
                                                   hid_to_complpids_sampler=train_aid_to_complpids, 
                                                   bm25_hid_to_tid_sampler=bm25_hid_to_tids,
                                                   miss_hids=miss_hids, 
                                                   hids_have_simnegs=hids_have_simnegs, 
                                                   hids_have_complnegs=hids_have_complnegs, 
                                                   hids_have_bm25negs=hids_have_bm25negs)
        if triple != 0:
            max2_q2p_triples.append(triple)
            total_triples += 1
print("miss_hids = {:,}, triples with simnegs = {:,}, triples with complnegs = {:,}, triples with bm25negs = {:,}".format(
    len(miss_hids), len(hids_have_simnegs), len(hids_have_complnegs), len(hids_have_bm25negs)
))
print("total triples = {:,}".format(total_triples))
print("="*75)

miss_hids = 3,697, triples with simnegs = 302,538, triples with complnegs = 31,483, triples with bm25negs = 266,550
total triples = 600,571
miss_hids = 2,834, triples with simnegs = 237,638, triples with complnegs = 23,994, triples with bm25negs = 190,436
total triples = 452,068
miss_hids = 1,643, triples with simnegs = 135,694, triples with complnegs = 13,351, triples with bm25negs = 100,286
total triples = 249,331
miss_hids = 694, triples with simnegs = 58,798, triples with complnegs = 5,662, triples with bm25negs = 41,450
total triples = 105,910


In [None]:
print("unique aids for simpids = {:,}, unique aids for complpids = {:,}".format(len(max5_h2sp), len(max5_h2cp)))
print("query oriented sampling max2, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max2_q2p_triples])), len(set([p for (_, p, _) in max2_q2p_triples]))))
print("query oriented sampling max5, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max5_q2p_triples])), len(set([p for (_, p, _) in max5_q2p_triples]))))
print("query oriented sampling max10, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max10_q2p_triples])), len(set([p for (_, p, _) in max10_q2p_triples]))))
print("query oriented sampling max15, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max15_q2p_triples])), len(set([p for (_, p, _) in max15_q2p_triples]))))

print("h2sp triples = {:,}, h2cp triples = {:,}".format(len(h2sp_triples), len(h2cp_triples)))
print("max2_q2p_triples = {:,}, max5_q2p _triples = {:,}, max10_q2p_triples = {:,}, max15_q2p_triples = {:,}".format(
    len(max2_q2p_triples), len(max5_q2p_triples), len(max10_q2p_triples), len(max15_q2p_triples)
))total_triples

unique aids for simpids = 315,575, unique aids for complpids = 82,580
query oriented sampling max2, unique queries = 54,237, unique items = 101,142
query oriented sampling max5, unique queries = 54,310, unique items = 226,564
query oriented sampling max10, unique queries = 54,331, unique items = 386,440
query oriented sampling max15, unique queries = 54,338, unique items = 491,323


In [6]:
import pickle

out_dir = os.path.join(in_dir, "unified_train/")
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

with open(os.path.join(out_dir, "train_graph.pkl"), "wb") as fout:
    pickle.dump(G, fout)

fn_to_tripleNrel = {
    "a2sp.train.tsv": (h2sp_triples, SIM_RELATION),
    "a2cp.train.tsv": (h2cp_triples, COMPL_RELATION),
    "max2_qorient_q2p.train.tsv": (max2_q2p_triples, REL_RELATION),
    "max5_qorient_q2p.train.tsv": (max5_q2p_triples, REL_RELATION),
    "max10_qorient_q2p.train.tsv": (max10_q2p_triples, REL_RELATION),
    "max15_qorient_q2p.train.tsv": (max10_q2p_triples, REL_RELATION),
    
    "a2sp.50.train.tsv": (random.sample(h2sp_triples, k=int(0.5*len(h2sp_triples))), SIM_RELATION),
    "a2sp.25.train.tsv": (random.sample(h2sp_triples, k=int(0.25*len(h2sp_triples))), SIM_RELATION)
}

for fn, (triples, relation) in fn_to_tripleNrel.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for (hid, pos_tid, neg_tid) in triples:
            fout.write(f"{hid}\t{pos_tid}\t{neg_tid}\t{relation}\n")
            
out_dir = os.path.join(in_dir, "unified_test/")
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

# for similar items
fn_to_aids = {
    "anchors.train.sim.tsv": list(train_aid_to_simpids.keys()),
    "anchors.val.sim.tsv": list(val_aid_to_simpids.keys()),
    "anchors.test.sim.tsv": list(test_aid_to_simpids.keys()),
}
for fn, aids in fn_to_aids.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for aid in aids:
            text = eid_to_text[aid]
            fout.write(f"{aid}\t{text}\t{SIM_RELATION}\n")
            
fn_to_arels = {
    "arels.train.sim.tsv": [(aid, pid) for aid, simpids in train_aid_to_simpids.items() for pid in simpids],
    "arels.val.sim.tsv": [(aid, pid) for aid, simpids in val_aid_to_simpids.items() for pid in simpids],
    "arels.test.sim.tsv": [(aid, pid) for aid, simpids in test_aid_to_simpids.items() for pid in simpids],
}
for fn, arels in fn_to_arels.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for (aid, pid) in arels:
            fout.write(f"{aid}\tQ0\t{pid}\t{1}\n")
            
# for complementary items
fn_to_aids = {
    "anchors.train.compl.tsv": list(train_aid_to_complpids.keys()),
    "anchors.val.compl.tsv": list(val_aid_to_complpids.keys()),
    "anchors.test.compl.tsv": list(test_aid_to_complpids.keys()),
}
for fn, aids in fn_to_aids.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for aid in aids:
            text = eid_to_text[aid]
            fout.write(f"{aid}\t{text}\t{COMPL_RELATION}\n")
fn_to_arels = {
    "arels.train.compl.tsv": [(aid, pid) for aid, complpids in train_aid_to_complpids.items() for pid in complpids],
    "arels.val.compl.tsv": [(aid, pid) for aid, complpids in val_aid_to_complpids.items() for pid in complpids],
    "arels.test.compl.tsv": [(aid, pid) for aid, complpids in test_aid_to_complpids.items() for pid in complpids]
}
for fn, arels in fn_to_arels.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for (aid, pid) in arels:
            fout.write(f"{aid}\tQ0\t{pid}\t{1}\n")
            
# for queries
fn_to_qids = {
    "queries.train.tsv": list(train_qid_to_pids.keys()),
    "queries.val.tsv": list(val_qid_to_pids.keys()),
    "queries.test.tsv": list(test_qid_to_pids.keys()),
}
for fn, qids in fn_to_qids.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for qid in qids:
            text = eid_to_text[qid]
            fout.write(f"{qid}\t{text}\t{REL_RELATION}\n")
            
            
    
            
fn_to_qrels = {
    "qrels.train.tsv": [(qid, pid) for qid, pids in train_qid_to_pids.items() for pid in pids],
    "qrels.val.tsv": [(qid, pid) for qid, pids in val_qid_to_pids.items() for pid in pids],
    "qrels.test.tsv": [(qid, pid) for (qid, pids) in test_qid_to_pids.items() for pid in pids],
}

for fn, qrels in fn_to_qrels.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for (qid, pid) in qrels:
            fout.write(f"{qid}\tQ0\t{pid}\t{1}\n")

In [8]:
# sanity check
import os
in_dir="/home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/"

eid_to_text = {}
with open(os.path.join(in_dir, "all_entities.tsv")) as fin:
    for line in fin:
        eid, text = line.strip().split("\t")
        eid_to_text[int(eid)] = text
        
out_dir = os.path.join(in_dir, "unified_train/")
for path in os.listdir(out_dir):
    path = os.path.join(out_dir, path)
    if path.endswith("tsv"):
        ! wc -l $path
        ! head -n 3 $path
        ! tail -n 3 $path
        print("="*100)

448988 /home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/unified_train/max15_qorient_q2p.train.tsv
1267636	517032	943546	is_relevant_to
1267636	517040	145583	is_relevant_to
1267636	517038	1131982	is_relevant_to
1281425	695445	1195271	is_relevant_to
1281425	971808	465588	is_relevant_to
1281425	695443	643557	is_relevant_to
600567 /home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/unified_train/max15_qorient_q2p.scbnegs.train.tsv
1267636	517038	515292	is_relevant_to
1267636	1093650	517038	is_relevant_to
1267636	517040	1141802	is_relevant_to
1281425	971808	1195242	is_relevant_to
1281425	695448	775789	is_relevant_to
1281425	953046	775794	is_relevant_to
533871 /home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/unified_train/a2sp.50.train.tsv
721795	721774	19743	is_similar_to
634415	634416	1040981	is_similar_

In [15]:
hid, pos_tid, neg_tid = (14 ,953828,973116)
print(eid_to_text[hid])
print(eid_to_text[pos_tid])
print(eid_to_text[neg_tid])

BIC Xtra Fun Cased Pencil, 2 Lead, Assorted Barrel Colors, 48-Count
Wood-Cased #2 HB Pencils, Yellow, Pre-sharpened, Class Pack, 1000 pencils
Meet Perfect Office Chair with Wheels Cheap Desk Chair Swivel Task Chairs Mid Back No Arms Ergonomic Mesh Chair Small Computer Chair Modern Height Adjustable for Home Space-Saving, Black
