In [2]:
import os
import pickle as pkl
from collections import defaultdict
import random 

from tqdm import tqdm
import networkx as nx
import pandas as pd
random.seed(4680)

in_dir = "/home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/"
datas = []
         
fns = [
    "train_aid_to_simpids.pkl",
    "train_aid_to_complpids.pkl",
    "train_qid_to_relpids.pkl",
    ]
for fn in fns:
    with open(os.path.join(in_dir, fn), "rb") as fin:
        datas.append(pkl.load(fin))

train_aid_to_simpids, train_aid_to_complpids, train_qid_to_pids = datas

G = nx.MultiDiGraph()
SIM_RELATION = "is_similar_to"
COMPL_RELATION = "is_complementary_to"
REL_RELATION = "is_relevant_to"

for aid, sim_pids in train_aid_to_simpids.items():
    triples = [(aid, sim_pid, {"type":SIM_RELATION}) for sim_pid in sim_pids]
    G.add_edges_from(triples)
    
for aid, compl_pids in train_aid_to_complpids.items():
    triples = [(aid, compl_pid, {"type":COMPL_RELATION}) for compl_pid in compl_pids]
    G.add_edges_from(triples)
    
for qid, pids in train_qid_to_pids.items():
    triples = [(pid, qid, {"type": REL_RELATION}) for pid in pids]
    G.add_edges_from(triples)
    
multi_edge_pairs = []
for n, nbrs_dict in tqdm(G.adj.items(), total=G.number_of_nodes()):
    for nbr_node, edge_attrs in nbrs_dict.items():
        assert len(edge_attrs) == 1 or len(edge_attrs) == 2
        if len(edge_attrs) == 2:
            multi_edge_pairs.append((n, nbr_node))
            
print("number of edges = {:,}, number of multi-attr edges = {:,}, ({:.3f})".format(G.number_of_edges(), len(multi_edge_pairs), 
                                                                                   len(multi_edge_pairs)/G.number_of_edges()))

PIDS = []
with open(os.path.join(in_dir, "collection_title.tsv")) as fin:
    for line in fin:
        PIDS.append(int(line.strip().split("\t")[0]))
print(f"max pids = {max(PIDS)}")
    

100%|██████████| 843496/843496 [00:04<00:00, 172645.60it/s]


number of edges = 2,746,193, number of multi-attr edges = 1,090, (0.000)
max pids = 1216069


In [3]:
eid_to_text = {}
with open(os.path.join(in_dir, "all_entities.tsv")) as fin:
    for line in fin:
        eid, text = line.strip().split("\t")
        eid_to_text[int(eid)] = text
        
run_path = os.path.join(in_dir, "runs/bm25.all.run")
df = pd.read_csv(run_path, sep=" ", names=["hid", "q0", "tid", "rank", "score", "model_name"])
bm25_hid_to_tids = {}
ignore_hids = set()
for hid, group in df.groupby("hid"):
    cand_tids = list(group.tid.values)
    if len(cand_tids) < 10:
        ignore_hids.add(int(hid))
    else:
        bm25_hid_to_tids[int(hid)] = [int(x) for x in cand_tids]
        
print("number of ignore hids = {}".format(len(ignore_hids)))

number of ignore hids = 2354


In [5]:
MAX_PID = max(PIDS)
def create_mixture_triples_for_search(hid, pos_tid, hid_to_simpids_sampler, hid_to_complpids_sampler, bm25_hid_to_tid_sampler,
                                     miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs):
    if pos_tid in hid_to_simpids_sampler:
        neg_tid = random.sample(hid_to_simpids_sampler[pos_tid], k=1)[0]
        assert neg_tid != pos_tid
        hids_have_simnegs.append(hid)
    elif pos_tid in hid_to_complpids_sampler:
        neg_tid = random.sample(hid_to_complpids_sampler[pos_tid], k=1)[0]
        assert neg_tid != pos_tid
        hids_have_complnegs.append(hid)
    elif hid in bm25_hid_to_tid_sampler:
        neg_tid = random.sample(bm25_hid_to_tid_sampler[hid], k=1)[0]
        while neg_tid == pos_tid:
            neg_tid = random.sample(bm25_hid_to_tid_sampler[hid], k=1)[0]
        hids_have_bm25negs.append(hid)
    else:
        miss_hids.append(hid)
        return 0 
    
    return (hid, pos_tid, neg_tid)

max15_q2p = {}
max10_q2p = {}
max5_q2p = {}
max2_q2p = {}

max15_q2p_triples = []
max10_q2p_triples = []
max5_q2p_triples = []
max2_q2p_triples = []
for qid, pids in train_qid_to_pids.items():
    max15_q2p[qid] = random.sample(pids, k=len(pids))[:15]
    max10_q2p[qid] = random.sample(pids, k=len(pids))[:10]
    max5_q2p[qid] = random.sample(pids, k=len(pids))[:5]
    max2_q2p[qid] = random.sample(pids, k=len(pids))[:2]
    
miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs = [], [], [], []
total_triples = 0   
for qid, pos_pids in max15_q2p.items():
    for pos_pid in pos_pids:
        triple = create_mixture_triples_for_search(qid, pos_pid, 
                                                   hid_to_simpids_sampler=train_aid_to_simpids,
                                                   hid_to_complpids_sampler=train_aid_to_complpids, 
                                                   bm25_hid_to_tid_sampler=bm25_hid_to_tids,
                                                   miss_hids=miss_hids, 
                                                   hids_have_simnegs=hids_have_simnegs, 
                                                   hids_have_complnegs=hids_have_complnegs, 
                                                   hids_have_bm25negs=hids_have_bm25negs)
        if triple != 0:
            max15_q2p_triples.append(triple)
            total_triples += 1
print("miss_hids = {:,}, triples with simnegs = {:,}, triples with complnegs = {:,}, triples with bm25negs = {:,}".format(
    len(miss_hids), len(hids_have_simnegs), len(hids_have_complnegs), len(hids_have_bm25negs)
))
print("total triples = {:,}".format(total_triples))
print("="*75)

miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs = [], [], [], []
total_triples = 0   
for qid, pos_pids in max10_q2p.items():
    for pos_pid in pos_pids:
        triple = create_mixture_triples_for_search(qid, pos_pid, 
                                                   hid_to_simpids_sampler=train_aid_to_simpids,
                                                   hid_to_complpids_sampler=train_aid_to_complpids, 
                                                   bm25_hid_to_tid_sampler=bm25_hid_to_tids,
                                                   miss_hids=miss_hids, 
                                                   hids_have_simnegs=hids_have_simnegs, 
                                                   hids_have_complnegs=hids_have_complnegs, 
                                                   hids_have_bm25negs=hids_have_bm25negs)
        if triple != 0:
            max10_q2p_triples.append(triple)
            total_triples += 1
print("miss_hids = {:,}, triples with simnegs = {:,}, triples with complnegs = {:,}, triples with bm25negs = {:,}".format(
    len(miss_hids), len(hids_have_simnegs), len(hids_have_complnegs), len(hids_have_bm25negs)
))
print("total triples = {:,}".format(total_triples))
print("="*75)

miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs = [], [], [], []
total_triples = 0   
for qid, pos_pids in max5_q2p.items():
    for pos_pid in pos_pids:
        triple = create_mixture_triples_for_search(qid, pos_pid, 
                                                   hid_to_simpids_sampler=train_aid_to_simpids,
                                                   hid_to_complpids_sampler=train_aid_to_complpids, 
                                                   bm25_hid_to_tid_sampler=bm25_hid_to_tids,
                                                   miss_hids=miss_hids, 
                                                   hids_have_simnegs=hids_have_simnegs, 
                                                   hids_have_complnegs=hids_have_complnegs, 
                                                   hids_have_bm25negs=hids_have_bm25negs)
        if triple != 0:
            max5_q2p_triples.append(triple)
            total_triples += 1
print("miss_hids = {:,}, triples with simnegs = {:,}, triples with complnegs = {:,}, triples with bm25negs = {:,}".format(
    len(miss_hids), len(hids_have_simnegs), len(hids_have_complnegs), len(hids_have_bm25negs)
))
print("total triples = {:,}".format(total_triples))
print("="*75)

miss_hids, hids_have_simnegs, hids_have_complnegs, hids_have_bm25negs = [], [], [], []
total_triples = 0   
for qid, pos_pids in max2_q2p.items():
    for pos_pid in pos_pids:
        triple = create_mixture_triples_for_search(qid, pos_pid, 
                                                   hid_to_simpids_sampler=train_aid_to_simpids,
                                                   hid_to_complpids_sampler=train_aid_to_complpids, 
                                                   bm25_hid_to_tid_sampler=bm25_hid_to_tids,
                                                   miss_hids=miss_hids, 
                                                   hids_have_simnegs=hids_have_simnegs, 
                                                   hids_have_complnegs=hids_have_complnegs, 
                                                   hids_have_bm25negs=hids_have_bm25negs)
        if triple != 0:
            max2_q2p_triples.append(triple)
            total_triples += 1
print("miss_hids = {:,}, triples with simnegs = {:,}, triples with complnegs = {:,}, triples with bm25negs = {:,}".format(
    len(miss_hids), len(hids_have_simnegs), len(hids_have_complnegs), len(hids_have_bm25negs)
))
print("total triples = {:,}".format(total_triples))
print("="*75)

miss_hids = 3,701, triples with simnegs = 302,568, triples with complnegs = 31,493, triples with bm25negs = 266,506
total triples = 600,567
miss_hids = 2,820, triples with simnegs = 237,647, triples with complnegs = 24,020, triples with bm25negs = 190,415
total triples = 452,082
miss_hids = 1,593, triples with simnegs = 135,917, triples with complnegs = 13,282, triples with bm25negs = 100,182
total triples = 249,381
miss_hids = 698, triples with simnegs = 58,778, triples with complnegs = 5,707, triples with bm25negs = 41,421
total triples = 105,906


In [7]:
#print("unique aids for simpids = {:,}, unique aids for complpids = {:,}".format(len(max5_h2sp), len(max5_h2cp)))
print("query oriented sampling max2, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max2_q2p_triples])), len(set([p for (_, p, _) in max2_q2p_triples]))))
print("query oriented sampling max5, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max5_q2p_triples])), len(set([p for (_, p, _) in max5_q2p_triples]))))
print("query oriented sampling max10, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max10_q2p_triples])), len(set([p for (_, p, _) in max10_q2p_triples]))))
print("query oriented sampling max15, unique queries = {:,}, unique items = {:,}".format(
    len(set([q for (q, _, _) in max15_q2p_triples])), len(set([p for (_, p, _) in max15_q2p_triples]))))

#print("h2sp triples = {:,}, h2cp triples = {:,}".format(len(h2sp_triples), len(h2cp_triples)))
#print("max2_q2p_triples = {:,}, max5_q2p _triples = {:,}, max10_q2p_triples = {:,}, max15_q2p_triples = {:,}".format(
#    len(max2_q2p_triples), len(max5_q2p_triples), len(max10_q2p_triples), len(max15_q2p_triples)
#))total_triples

query oriented sampling max2, unique queries = 54,243, unique items = 101,144
query oriented sampling max5, unique queries = 54,305, unique items = 226,267
query oriented sampling max10, unique queries = 54,330, unique items = 386,333
query oriented sampling max15, unique queries = 54,338, unique items = 491,241


In [8]:
import pickle

out_dir = os.path.join(in_dir, "unified_train/")
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

fn_to_tripleNrel = {
    "max2_qorient_q2p.scbnegs.train.tsv": (max2_q2p_triples, REL_RELATION),
    "max5_qorient_q2p.scbnegs.train.tsv": (max5_q2p_triples, REL_RELATION),
    "max10_qorient_q2p.scbnegs.train.tsv": (max10_q2p_triples, REL_RELATION),
    "max15_qorient_q2p.scbnegs.train.tsv": (max15_q2p_triples, REL_RELATION),
}

for fn, (triples, relation) in fn_to_tripleNrel.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        for (hid, pos_tid, neg_tid) in triples:
            fout.write(f"{hid}\t{pos_tid}\t{neg_tid}\t{relation}\n")
            

In [16]:
# sanity check
in_dir="/home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/"
out_dir = os.path.join(in_dir, "unified_train/")
for path in os.listdir(out_dir):
    if not path.endswith("scbnegs.train.tsv"):
        continue
    path = os.path.join(out_dir, path)
    ! wc -l $path
    ! head -n 5 $path
    ! tail -n 5 $path
    print("="*100)

600567 /home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/unified_train/max15_qorient_q2p.scbnegs.train.tsv
1267636	517038	515292	is_relevant_to
1267636	1093650	517038	is_relevant_to
1267636	517040	1141802	is_relevant_to
1267636	517033	777499	is_relevant_to
1267636	517034	90103	is_relevant_to
1281425	973124	1195268	is_relevant_to
1281425	695445	1195266	is_relevant_to
1281425	971808	1195242	is_relevant_to
1281425	695448	775789	is_relevant_to
1281425	953046	775794	is_relevant_to
105906 /home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/unified_train/max2_qorient_q2p.scbnegs.train.tsv
1267636	517030	515433	is_relevant_to
1267636	517039	428761	is_relevant_to
1228232	116172	116161	is_relevant_to
1228232	116163	856362	is_relevant_to
1275084	600152	489173	is_relevant_to
1279154	651903	997221	is_relevant_to
1224704	82034	661050	is_relevant_to
1224704	82036	5376	is_relevan

In [18]:
hid, pos_tid, neg_tid = (1224704,82036,5376)
print(eid_to_text[hid])
print(eid_to_text[pos_tid])
print(eid_to_text[neg_tid])

blonde wig bob
QDBOWIN QUEEN HAIR Brown Rooted Honey Blonde Ombre Lace Front Wigs for Women Flawless Wavy Bob Hair Heat Resistant Synthetic Wig Half Hand Tied 14 inch
Highlight Straight Bob Wigs Honey Blonde Ombre Short Wigs with Bangs 4/27 Mix Brown Color Unprocessed Brazilian Virgin Hair for Black Women Non Lace Front Human Hair (12 Inch)
