In [1]:
import os
import pickle as pkl
from collections import defaultdict
import random 

from tqdm import tqdm
import networkx as nx
import pandas as pd
random.seed(4680)

in_dir = "/home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/data/processed/public/task_1_query_product_ranking/"
datas = []
         
fns = [
    "train_aid_to_simpids.pkl",
    "val_aid_to_simpids.pkl",
    "test_aid_to_simpids.pkl",
    
    "train_aid_to_complpids.pkl",
    "val_aid_to_complpids.pkl",
    "test_aid_to_complpids",
    
    "train_qid_to_relpids.pkl",
    "val_qid_to_relpids.pkl",
    "test_qid_to_relpids",
    ]
for fn in fns:
    with open(os.path.join(in_dir, fn), "rb") as fin:
        datas.append(pkl.load(fin))

train_aid_to_simpids, val_aid_to_simpids,test_aid_to_simpids, train_aid_to_complpids, val_aid_to_complpids, test_aid_to_complpids, train_qid_to_pids, val_qid_to_pids,test_qid_to_pids = datas

G = nx.MultiDiGraph()
SIM_RELATION = "is_similar_to"
COMPL_RELATION = "is_complementary_to"
REL_RELATION = "is_relevant_to"

for aid, sim_pids in train_aid_to_simpids.items():
    triples = [(aid, sim_pid, {"type":SIM_RELATION}) for sim_pid in sim_pids]
    G.add_edges_from(triples)
    
for aid, compl_pids in train_aid_to_complpids.items():
    triples = [(aid, compl_pid, {"type":COMPL_RELATION}) for compl_pid in compl_pids]
    G.add_edges_from(triples)
    
for qid, pids in train_qid_to_pids.items():
    triples = [(pid, qid, {"type": REL_RELATION}) for pid in pids]
    G.add_edges_from(triples)
    
multi_edge_pairs = []
for n, nbrs_dict in tqdm(G.adj.items(), total=G.number_of_nodes()):
    for nbr_node, edge_attrs in nbrs_dict.items():
        assert len(edge_attrs) == 1 or len(edge_attrs) == 2
        if len(edge_attrs) == 2:
            multi_edge_pairs.append((n, nbr_node))
            
print("number of edges = {:,}, number of multi-attr edges = {:,}, ({:.3f})".format(G.number_of_edges(), len(multi_edge_pairs), 
                                                                                   len(multi_edge_pairs)/G.number_of_edges()))

100%|██████████| 308664/308664 [00:01<00:00, 160232.34it/s]


number of edges = 1,245,760, number of multi-attr edges = 493, (0.000)


In [2]:
def create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text, sampler=None):
    if sampler != None:
        assert type(sampler) == dict, type(sampler)
        if hid not in sampler:
            miss_hids.append(hid)
            return 0
    if eid_to_text[hid] == eid_to_text[pos_tid]:
        duplicate_pairs.append((hid, pos_tid))
        return 0
    
    if sampler != None:
        neg_tid = random.sample(sampler[hid], k=1)[0]
        while neg_tid == pos_tid:
            neg_tid = random.sample(sampler[hid], k=1)[0]
    else:
        neg_tid = random.sample(range(480_000), k=1)[0]
        while neg_tid == pos_tid:
            neg_tid = random.sample(range(480_000), k=1)[0]
            
    return (hid, pos_tid, neg_tid)


eid_to_text = {}
with open(os.path.join(in_dir, "all_entities.tsv")) as fin:
    for line in fin:
        eid, text = line.strip().split("\t")
        eid_to_text[int(eid)] = text
        
run_path = os.path.join(in_dir, "runs/bm25.all.run")
df = pd.read_csv(run_path, sep=" ", names=["hid", "q0", "tid", "rank", "score", "model_name"])
bm25_hid_to_tids = {}
ignore_hids = set()
for hid, group in df.groupby("hid"):
    cand_tids = list(group.tid.values)
    if len(cand_tids) < 10:
        ignore_hids.add(int(hid))
    else:
        bm25_hid_to_tids[int(hid)] = [int(x) for x in cand_tids]
        
print("number of ignore hids = {}".format(len(ignore_hids)))

number of ignore hids = 1242


In [3]:
max5_h2sp = {}
max5_h2cp = {}
max5_h2q = {}

for head_node, nbrs_dict in tqdm(G.adj.items(), total=G.number_of_nodes()):
    sim_pids = []
    compl_pids = []
    rel_qids = []
    for tail_node, edge_attrs in nbrs_dict.items():
        assert len(edge_attrs) == 1 or len(edge_attrs) == 2
        relations = []
        for no, edge_attr in edge_attrs.items():
            relations.append(edge_attr["type"])
        for rel in relations:
            assert rel in [SIM_RELATION, COMPL_RELATION, REL_RELATION]
            if rel in SIM_RELATION:
                sim_pids.append(tail_node)
            if rel in COMPL_RELATION:
                compl_pids.append(tail_node)
            if rel in REL_RELATION:
                rel_qids.append(tail_node)
    if len(sim_pids) != 0:
        max5_h2sp[head_node] = random.sample(sim_pids, k=len(sim_pids))[:5]
    if len(compl_pids) != 0:
        max5_h2cp[head_node] = random.sample(compl_pids, k=len(compl_pids))[:5]
    if len(rel_qids) != 0:
        max5_h2q[head_node] = random.sample(rel_qids, k=len(rel_qids))[:5]
        
miss_hids = []
duplicate_pairs = []

h2sp_triples = []
h2cp_triples = []
q2h_triples = []
for hid, tail_ids in max5_h2sp.items():
    for pos_tid in tail_ids:
        triple = create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text)
        if triple != 0:
            h2sp_triples.append(triple)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)
for hid, tail_ids in max5_h2cp.items():
    for pos_tid in tail_ids:
        triple = create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text, sampler=bm25_hid_to_tids)
        if triple != 0:
            h2cp_triples.append(triple)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)
for pos_tid, head_ids in max5_h2q.items():
    for hid in head_ids:
        triple = create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text, sampler=bm25_hid_to_tids)
        if triple != 0:
            q2h_triples.append(triple)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)

100%|██████████| 308664/308664 [00:05<00:00, 53787.87it/s] 


miss_hids = 0, duplicate_pairs = 206
miss_hids = 78, duplicate_pairs = 220
miss_hids = 2,598, duplicate_pairs = 221


In [4]:
max5_q2p = {}
max2_q2p = {}
max5_q2p_triples = []
max2_q2p_triples = []
for qid, pids in train_qid_to_pids.items():
    max5_q2p[qid] = random.sample(pids, k=len(pids))[:5]
    max2_q2p[qid] = random.sample(pids, k=len(pids))[:2]
    
for qid, pos_pids in max5_q2p.items():
    for pos_pid in pos_pids:
        triple = create_triples(qid, pos_pid, miss_hids, duplicate_pairs, eid_to_text, sampler=bm25_hid_to_tids)
        if triple != 0:
            max5_q2p_triples.append(triple)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)

for qid, pos_pids in max2_q2p.items():
    for pos_pid in pos_pids:
        triple = create_triples(qid, pos_pid, miss_hids, duplicate_pairs, eid_to_text, sampler=bm25_hid_to_tids)
        if triple != 0:
            max2_q2p_triples.append(triple)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)

miss_hids = 3,949, duplicate_pairs = 222
miss_hids = 4,562, duplicate_pairs = 223


In [5]:
print("query oriented sampling max2, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max2_q2p_triples])), len(set([p for (_, p, _) in max2_q2p_triples]))))
print("query oriented sampling max5, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max5_q2p_triples])), len(set([p for (_, p, _) in max5_q2p_triples]))))
print("item oriented sampling, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in q2h_triples])), len(set([p for (_, p, _) in q2h_triples]))))

query oriented sampling max2, unique queries = 16,392, unique items = 31,144
query oriented sampling max5, unique queries = 16,392, unique items = 67,724
item oriented sampling, unique queries = 16,389, unique items = 134,228


In [6]:
len(h2sp_triples), len(h2cp_triples), len(q2h_triples), len(max2_q2p_triples), len(max5_q2p_triples), \
len(test_qid_to_pids), len(test_aid_to_simpids), len(test_aid_to_complpids)

(481721, 101280, 143317, 31664, 70024, 2089, 14630, 4884)

In [7]:
import pickle

out_dir = os.path.join(in_dir, "unified_train/")
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

with open(os.path.join(out_dir, "train_graph.pkl"), "wb") as fout:
    pickle.dump(G, fout)

fn_to_tripleNrel = {
    "a2sp.train.tsv": (h2sp_triples, SIM_RELATION),
    "a2cp.train.tsv": (h2cp_triples, COMPL_RELATION),
    "q2a.train.tsv": (q2h_triples, REL_RELATION),
    "max2_qorient_q2p.train.tsv": (max2_q2p_triples, REL_RELATION),
    "max5_qorient_q2p.train.tsv": (max5_q2p_triples, REL_RELATION),
    
    "q2a.50.train.tsv": (random.sample(q2h_triples, k=int(0.5*len(q2h_triples))), REL_RELATION),
    "q2a.17.train.tsv": (random.sample(q2h_triples, k=int(0.17*len(q2h_triples))), REL_RELATION),
    
    "a2sp.50.train.tsv": (random.sample(h2sp_triples, k=int(0.5*len(h2sp_triples))), SIM_RELATION)
}

for fn, (triples, relation) in fn_to_tripleNrel.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for (hid, pos_tid, neg_tid) in triples:
            fout.write(f"{hid}\t{pos_tid}\t{neg_tid}\t{relation}\n")
            
out_dir = os.path.join(in_dir, "unified_test/")
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

# for similar items
fn_to_aids = {
    "anchors.train.sim.tsv": list(train_aid_to_simpids.keys()),
    "anchors.val.sim.tsv": list(val_aid_to_simpids.keys()),
    "anchors.test.sim.tsv": list(test_aid_to_simpids.keys()),
}
for fn, aids in fn_to_aids.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for aid in aids:
            text = eid_to_text[aid]
            fout.write(f"{aid}\t{text}\t{SIM_RELATION}\n")
            
fn_to_arels = {
    "arels.train.sim.tsv": [(aid, pid) for aid, simpids in train_aid_to_simpids.items() for pid in simpids],
    "arels.val.sim.tsv": [(aid, pid) for aid, simpids in val_aid_to_simpids.items() for pid in simpids],
    "arels.test.sim.tsv": [(aid, pid) for aid, simpids in test_aid_to_simpids.items() for pid in simpids],
}
for fn, arels in fn_to_arels.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for (aid, pid) in arels:
            fout.write(f"{aid}\tQ0\t{pid}\t{1}\n")
            
# for complementary items
fn_to_aids = {
    "anchors.train.compl.tsv": list(train_aid_to_complpids.keys()),
    "anchors.val.compl.tsv": list(val_aid_to_complpids.keys()),
    "anchors.test.compl.tsv": list(test_aid_to_complpids.keys()),
}
for fn, aids in fn_to_aids.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for aid in aids:
            text = eid_to_text[aid]
            fout.write(f"{aid}\t{text}\t{COMPL_RELATION}\n")
fn_to_arels = {
    "arels.train.compl.tsv": [(aid, pid) for aid, complpids in train_aid_to_complpids.items() for pid in complpids],
    "arels.val.compl.tsv": [(aid, pid) for aid, complpids in val_aid_to_complpids.items() for pid in complpids],
    "arels.test.compl.tsv": [(aid, pid) for aid, complpids in test_aid_to_complpids.items() for pid in complpids]
}
for fn, arels in fn_to_arels.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for (aid, pid) in arels:
            fout.write(f"{aid}\tQ0\t{pid}\t{1}\n")
            
# for queries
fn_to_qids = {
    "queries.train.tsv": list(train_qid_to_pids.keys()),
    "queries.val.tsv": list(val_qid_to_pids.keys()),
    "queries.test.tsv": list(test_qid_to_pids.keys()),
}
for fn, qids in fn_to_qids.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for qid in qids:
            text = eid_to_text[qid]
            fout.write(f"{qid}\t{text}\t{REL_RELATION}\n")
            
            
    
            
fn_to_qrels = {
    "qrels.train.tsv": [(qid, pid) for qid, pids in train_qid_to_pids.items() for pid in pids],
    "qrels.val.tsv": [(qid, pid) for qid, pids in val_qid_to_pids.items() for pid in pids],
    "qrels.test.tsv": [(qid, pid) for (qid, pids) in test_qid_to_pids.items() for pid in pids],
}

for fn, qrels in fn_to_qrels.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for (qid, pid) in qrels:
            fout.write(f"{qid}\tQ0\t{pid}\t{1}\n")

In [8]:
# sanity check
out_dir = os.path.join(in_dir, "unified_train/")
for path in os.listdir(out_dir):
    path = os.path.join(out_dir, path)
    ! wc -l $path
    ! head -n 3 $path
    ! tail -n 3 $path
    print("="*100)

240860 /home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/data/processed/public/task_1_query-product_ranking/unified_train/a2sp.50.train.tsv
224793	224795	451839	is_similar_to
444401	404653	47706	is_similar_to
6061	6060	39532	is_similar_to
256329	467741	356894	is_similar_to
161768	2820	160752	is_similar_to
292196	394419	371576	is_similar_to
70024 /home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/data/processed/public/task_1_query-product_ranking/unified_train/max5_qorient_q2p.train.tsv
483187	13535	403042	is_relevant_to
483187	13530	378779	is_relevant_to
483187	13528	334862	is_relevant_to
487692	70931	98065	is_relevant_to
494832	165332	2691	is_relevant_to
494832	165328	360286	is_relevant_to
155080 /home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/data/processed/public/task_1_query-product_ranking/unified_train/train_graph.pkl
�cnetworkx.classes.multidigraph
MultiDiGraph
q )�q}q(X   edge_key_dict_factoryqcbuiltins
 uJ

In [13]:
hid, pos_tid, neg_tid = (0,263895,330719)
print(eid_to_text[hid])
print(eid_to_text[pos_tid])
print(eid_to_text[neg_tid])

Amazon Basics Woodcased #2 Pencils, Unsharpened, HB Lead - Box of 144, Bulk Box
Environmentally friendly black wood pencils, (30 pcs per barrel) triangle grip pen design, graphite HB lead core with eraser, suitable for children and adults to write sketches and paint
Going, Going Gong / Clever Levers
