In [1]:
import os
import pickle as pkl
from collections import defaultdict
import random 

from tqdm import tqdm
import networkx as nx
import pandas as pd
# Seed the global RNG once: every later random draw in this notebook (positive
# capping, negative sampling, the 50% subset) depends on this exact stream.
random.seed(4680)

in_dir = "/home/jupyter/unity_jointly_rec_and_search/datasets/amazon_esci_dataset/task_2_multiclass_product_classification/"

# Training relation maps, listed in the same order they are unpacked below.
fns = [
    "train_aid_to_simpids.pkl",
    "train_aid_to_complpids.pkl",
    "train_qid_to_relpids.pkl",
]


def _read_pickle(name):
    """Unpickle `name` from `in_dir` (project-generated files; never point this at untrusted data)."""
    pickle_path = os.path.join(in_dir, name)
    with open(pickle_path, "rb") as handle:
        return pkl.load(handle)


datas = [_read_pickle(name) for name in fns]
train_aid_to_simpids, train_aid_to_complpids, train_qid_to_pids = datas


def _shuffle_then_take(items, limit):
    """Return at most `limit` items of `items` in a random order.

    The whole sequence is shuffled with `random.sample(items, k=len(items))` and then
    truncated, so each call consumes the seeded global RNG exactly as before.
    """
    shuffled = random.sample(items, k=len(items))
    return shuffled[:limit]


# Cap positives per head entity: 5 for the two product->product relations, 10 for
# query->product relevance. The comprehensions run in this order (a2s, a2c, q2p),
# which fixes the RNG stream seen by later cells.
max5_a2s = {aid: _shuffle_then_take(pids, 5) for aid, pids in train_aid_to_simpids.items()}
max5_a2c = {aid: _shuffle_then_take(pids, 5) for aid, pids in train_aid_to_complpids.items()}
max10_q2p = {qid: _shuffle_then_take(pids, 10) for qid, pids in train_qid_to_pids.items()}
    


In [2]:
def create_self_triples(hid, pos_tid, hid_to_postids, sampler):
    """Build one (head, positive, negative) training triple by rejection sampling.

    A negative tail id is drawn uniformly at random from ``sampler[hid]`` and redrawn
    until it is not one of the known positives ``hid_to_postids[hid]``.

    Args:
        hid: head entity id; must be a key of both ``hid_to_postids`` and ``sampler``.
        pos_tid: positive tail id placed in the triple unchanged.
        hid_to_postids: mapping head id -> collection of ALL positive tail ids for that
            head (not just the capped subset), so that no positive is used as a negative.
        sampler: mapping head id -> sequence of candidate tail ids to draw negatives from
            (in this notebook, the model's retrieved candidates per head).

    Returns:
        ``(hid, pos_tid, neg_tid)``.

    Raises:
        ValueError: if ``sampler[hid]`` has no candidate outside the positives (including
            when it is empty). Without this check the rejection loop would never terminate.
    """
    candidates = sampler[hid]
    # Set for O(1) membership tests inside the rejection loop.
    positives = set(hid_to_postids[hid])
    if not any(tid not in positives for tid in candidates):
        raise ValueError(
            f"no negative candidate for head {hid}: all {len(candidates)} "
            "candidates are known positives"
        )
    # random.choice draws exactly like random.sample(seq, k=1)[0] (one randbelow(len)),
    # so the seeded RNG stream is unchanged.
    neg_tid = random.choice(candidates)
    while neg_tid in positives:
        neg_tid = random.choice(candidates)
    return (hid, pos_tid, neg_tid)

# Entity id -> text for every entity in the dataset; used below to eyeball triples.
eid_to_text = {}
with open(os.path.join(in_dir, "all_entities.tsv"), encoding="utf-8") as fin:
    for line in fin:
        line = line.rstrip("\r\n")
        if not line:
            continue  # tolerate blank (e.g. trailing) lines
        # Split on the first tab only, so texts containing tabs (or empty texts) don't crash.
        eid, text = line.split("\t", 1)
        eid_to_text[int(eid)] = text.strip()

exp_dir = "/home/jupyter/unity_jointly_rec_and_search/experiments/amazon_esci/task2/phase_2/experiment_09-20_092917"
run_path = os.path.join(exp_dir, "runs/checkpoint_latest.all.run")
# Each head in the run file is expected to have exactly this many candidate rows.
N_CANDIDATES_PER_HEAD = 200

df = pd.read_csv(run_path, sep="\t", names=["hid", "tid", "rank", "score"])

# head id -> candidate tail ids, in file order; this is the negative sampler for cell 3.
self_hid_to_tids = {}
number_of_group = df["hid"].nunique()
# Iterating the `tid` SeriesGroupBy avoids materialising a full sub-DataFrame per head.
for hid, tids in tqdm(df.groupby("hid")["tid"], total=number_of_group):
    cand_tids = [int(x) for x in tids.tolist()]
    assert len(cand_tids) == N_CANDIDATES_PER_HEAD, (
        f"hid {hid}: expected {N_CANDIDATES_PER_HEAD} candidates, got {len(cand_tids)}"
    )
    self_hid_to_tids[int(hid)] = cand_tids
        

100%|██████████| 1284209/1284209 [03:03<00:00, 6986.86it/s]


In [3]:
def _triples_for(hid_to_sampled_pos, hid_to_all_pos):
    """Make one (hid, pos, neg) triple per sampled positive of every head.

    Negatives come from `self_hid_to_tids` and are rejected if they appear in the head's
    full positive set `hid_to_all_pos[hid]`, not just the capped sample.
    """
    return [
        create_self_triples(hid, pos_tid, hid_to_all_pos, self_hid_to_tids)
        for hid, pos_tids in hid_to_sampled_pos.items()
        for pos_tid in pos_tids
    ]


# Call order matters: each call advances the seeded global RNG.
a2s_triples = _triples_for(max5_a2s, train_aid_to_simpids)
a2c_triples = _triples_for(max5_a2c, train_aid_to_complpids)
q2p_triples = _triples_for(max10_q2p, train_qid_to_pids)

SIM_RELATION = "is_similar_to"
COMPL_RELATION = "is_complementary_to"
REL_RELATION = "is_relevant_to"

# Fraction of the a2s triples kept in the down-sampled "a2s.50" file.
A2S_SUBSET_FRACTION = 0.5

out_dir = os.path.join(exp_dir, "self_train")
# Idempotent on re-run, creates missing parents, and has no exists()/mkdir() race.
os.makedirs(out_dir, exist_ok=True)

# Output file name -> (triples, relation label written in the 4th column).
fn_to_data = {
    "a2s.train.tsv": (a2s_triples, SIM_RELATION),
    "a2c.train.tsv": (a2c_triples, COMPL_RELATION),
    "q2p.train.tsv": (q2p_triples, REL_RELATION),
    # Random subset drawn from the seeded global RNG after all triples were generated.
    "a2s.50.train.tsv": (
        random.sample(a2s_triples, k=int(A2S_SUBSET_FRACTION * len(a2s_triples))),
        SIM_RELATION,
    ),
}

# One row per triple: hid \t pos_tid \t neg_tid \t relation
for fn, (triples, relation) in fn_to_data.items():
    out_path = os.path.join(out_dir, fn)
    with open(out_path, "w", encoding="utf-8") as fout:
        fout.writelines(
            f"{hid}\t{pos_tid}\t{neg_tid}\t{relation}\n"
            for hid, pos_tid, neg_tid in triples
        )


In [4]:
# sanity check
for path in os.listdir(out_dir):
    path = os.path.join(out_dir, path)
    ! wc -l $path
    ! head -n 3 $path
    ! tail -n 3 $path
    print("="*100)

454902 /home/jupyter/unity_jointly_rec_and_search/experiments/amazon_esci/task2/phase_2/experiment_09-20_092917/self_train/q2p.train.tsv
1267636	517030	1025121	is_relevant_to
1267636	517041	876388	is_relevant_to
1267636	892198	694296	is_relevant_to
1281425	1195273	692297	is_relevant_to
1281425	971808	687394	is_relevant_to
1281425	775780	1033410	is_relevant_to
1068035 /home/jupyter/unity_jointly_rec_and_search/experiments/amazon_esci/task2/phase_2/experiment_09-20_092917/self_train/a2s.train.tsv
881159	19	953813	is_similar_to
881159	953826	680821	is_similar_to
881159	881221	173547	is_similar_to
766742	766747	998232	is_similar_to
766742	766758	192201	is_similar_to
766742	766722	1008598	is_similar_to
184468 /home/jupyter/unity_jointly_rec_and_search/experiments/amazon_esci/task2/phase_2/experiment_09-20_092917/self_train/a2c.train.tsv
881159	1032053	500117	is_complementary_to
881159	1033049	655386	is_complementary_to
881159	953818	1017029	is_complementary_to
766746	766736	893841	is_comple

In [5]:
# Eyeball one (head, positive, negative) triple as text: head, then positive, then negative.
# NOTE(review): the a2c rows printed by the sanity-check cell pair (881159, 953818) with
# negative 1017029, not 189332; this triple is presumably from an earlier run — confirm.
hid, pos_tid, neg_tid = 881159, 953818, 189332
for eid in (hid, pos_tid, neg_tid):
    print(eid_to_text[eid])

Dixon No. 2 Yellow Pencils, Wood-Cased, Black Core, #2 HB Soft, 12-Count (14402)
Arteza HB Pencils #2, Pack of 48, Wood-Cased Graphite Pencils in Bulk, Pre-Sharpened, with Latex-Free Erasers, Office & School Supplies for Exams and Classrooms
June Gold 8 Pack 0.5 mm HB #2 Mechanical Pencils, Extra Long Spin Eraser, 2 Lead Dispensers/w 220 Refills & 8 Refill Erasers, Break Resistant Lead, Soft Non-Slip Grip
