In [1]:
import os
import pickle
import glob

import pandas as pd 
import numpy as np

# Directory that holds the CSV inputs read below.
bq_in_dir = "/home/jupyter/data_transfer/data/"


def _read_indexed_csv(file_name):
    """Read ``file_name`` from ``bq_in_dir`` with its first column as the index."""
    return pd.read_csv(os.path.join(bq_in_dir, file_name), index_col=0)


sim_rec_df = _read_indexed_csv("hansi_rec_ClicksData_5core.csv")
compl_rec_df = _read_indexed_csv("comp_rec_ClicksData_2core.csv")
search_df = _read_indexed_csv("search_ClicksData_1year_5core.csv")
product_df = _read_indexed_csv("all_products_info.csv")

# Unique ids: all products, and both sides of the complementary-rec table.
all_products = set(product_df["product_id"])
anchors = set(compl_rec_df["anchor"])
compl_ivms = set(compl_rec_df["ivm"])
all_compl_ivms = anchors | compl_ivms

print("================================ For anchor_to_compl_ivms: ===================================")
print(
    "number of unique product = {:,}, anchors = {:,}, complementary_compl_ivms = {:,}".format(
        len(all_products), len(anchors), len(compl_ivms)
    )
)
# Every anchor and every complementary item has to be a known product.
assert anchors <= all_products and compl_ivms <= all_products, (
    len(all_products & anchors), len(anchors), len(all_products & compl_ivms), len(compl_ivms)
)

# Same bookkeeping for the similar-rec table, then the union of both item sets.
all_sim_ivms = set(sim_rec_df["anchor"]) | set(sim_rec_df["ivm"])
all_ivms = all_compl_ivms | all_sim_ivms
_sim_compl_shared = all_compl_ivms & all_sim_ivms
print("================================ After updating anchor_to_similar_ivms: ===================================")
print("all_compl_ivms = {:,}, all_sim_ivms = {:,}".format(len(all_compl_ivms), len(all_sim_ivms)))
print(
    "sim_compl_intersect = {:,} ({:.3f})".format(
        len(_sim_compl_shared), len(_sim_compl_shared) / len(all_compl_ivms)
    )
)
print("all_ivms = {:,}".format(len(all_ivms)))

# Every item seen in either recommendation table has to be a known product.
assert all_ivms <= all_products, (len(all_products & all_ivms), len(all_ivms))

def _collect_lists(frame, key_col, value_col):
    """Group ``frame`` by ``key_col`` and gather each group's ``value_col`` values into a list."""
    return frame.groupby(key_col)[value_col].apply(list)


def _list_lengths(grouped):
    """Return an array with the length of every list stored in ``grouped``."""
    return np.array([len(items) for items in grouped.values])


# Search clicks, indexed both ways.
query_to_ivms = _collect_lists(search_df, "query", "ivm")
ivm_to_queries = _collect_lists(search_df, "ivm", "query")
# Number of queries attached to each ivm.
query_lengths = _list_lengths(ivm_to_queries)
all_queries = set(search_df["query"])
print("all queries = {:,}".format(len(all_queries)))
assert len(all_queries) == len(query_to_ivms), len(query_to_ivms)
print(
    "total ivms (queries) = {:,}, length >=3 = {:,}, length >= 5 = {:,}".format(
        len(query_lengths), np.sum(query_lengths >= 3), np.sum(query_lengths >= 5)
    )
)

# Complementary recommendations per anchor.
anchor_to_compl_ivms = _collect_lists(compl_rec_df, "anchor", "ivm")
compl_ivms_length = _list_lengths(anchor_to_compl_ivms)
print("================================ For anchor_to_compl_ivms: ===================================")
print(
    "total_compl_ivms = {:,}, length >=3 = {:,}, length >= 5 = {:,}".format(
        len(compl_ivms_length), np.sum(compl_ivms_length >= 3), np.sum(compl_ivms_length >= 5)
    )
)

# Similar recommendations per anchor.
anchor_to_sim_ivms = _collect_lists(sim_rec_df, "anchor", "ivm")

  mask |= (ar1 == a)


number of unique product = 2,260,878, anchors = 86,870, complementary_compl_ivms = 65,561
all_compl_ivms = 109,758, all_sim_ivms = 256,765
sim_compl_intersect = 87,425 (0.797)
all_ivms = 279,098
all queries = 953,773
total ivms (queries) = 360,744, length >=3 = 196,481, length >= 5 = 142,527
total_compl_ivms = 86,870, length >=3 = 35,837, length >= 5 = 22,121


In [2]:
import pickle

user_dir = "/home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/"


def _unpickle(file_name):
    """Load the pickled object stored as ``file_name`` inside ``user_dir``."""
    # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
    with open(os.path.join(user_dir, file_name), "rb") as handle:
        return pickle.load(handle)


def _translate(groups, key_to_id, value_to_id):
    """Re-key ``groups`` through ``key_to_id`` and map every listed value through ``value_to_id``."""
    return {
        key_to_id[key]: [value_to_id[value] for value in values]
        for key, values in groups.items()
    }


ivm_to_pid = _unpickle("ivm_to_pid.pkl")
query_to_qid = _unpickle("query_to_qid.pkl")

# The same relations as above, expressed with integer ids instead of raw ivm / query values.
pid_to_qids = _translate(ivm_to_queries, ivm_to_pid, query_to_qid)
qid_to_pids = _translate(query_to_ivms, query_to_qid, ivm_to_pid)
aid_to_complpids = _translate(anchor_to_compl_ivms, ivm_to_pid, ivm_to_pid)
aid_to_simpids = _translate(anchor_to_sim_ivms, ivm_to_pid, ivm_to_pid)

In [3]:
# start create graph

import random 
import pickle as pkl
from tqdm import tqdm
import networkx as nx
from collections import defaultdict
# Fix the seed so every random draw below is repeatable.
random.seed(4680)

test_dir = os.path.join(user_dir, "zero_shot")


def _read_exclusion(file_name):
    """Load one pickled exclusion mapping from ``test_dir``."""
    # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
    with open(os.path.join(test_dir, file_name), "rb") as handle:
        return pkl.load(handle)


exclude_aid_to_simpids = _read_exclusion("exclude_aid_to_simpids.pkl")
exclude_aid_to_complpids = _read_exclusion("exclude_aid_to_complpids.pkl")
exclude_qid_to_relpids = _read_exclusion("exclude_qid_to_relpids.pkl")

# Training relations: a head that appears in the matching exclude_* dict loses the
# tails listed there (set difference, so the result is de-duplicated); any other
# head keeps its original list object.
train_aid_to_simpids = {
    aid: (
        list(set(simpids).difference(set(exclude_aid_to_simpids[aid])))
        if aid in exclude_aid_to_simpids
        else simpids
    )
    for aid, simpids in tqdm(aid_to_simpids.items(), total=len(aid_to_simpids))
}

train_aid_to_complpids = {
    aid: (
        list(set(complpids).difference(set(exclude_aid_to_complpids[aid])))
        if aid in exclude_aid_to_complpids
        else complpids
    )
    for aid, complpids in tqdm(aid_to_complpids.items(), total=len(aid_to_complpids))
}

# NOTE(review): unlike the two mappings above, the excluded pids are passed to
# difference() without being wrapped in set() first.
train_qid_to_pids = {
    qid: (
        list(set(pids).difference(exclude_qid_to_relpids[qid]))
        if qid in exclude_qid_to_relpids
        else pids
    )
    for qid, pids in tqdm(qid_to_pids.items(), total=len(qid_to_pids))
}


def _num_pairs(head_to_tails):
    """Total number of (head, tail) pairs stored in ``head_to_tails``."""
    return sum(len(tails) for tails in head_to_tails.values())


print(
    "number of arels for sim_rec before exclusion and after exclusion = {:,}, {:,}".format(
        _num_pairs(aid_to_simpids), _num_pairs(train_aid_to_simpids)
    )
)
print(
    "number of arels for compl_rec before exclusion and after exclusion = {:,}, {:,}".format(
        _num_pairs(aid_to_complpids), _num_pairs(train_aid_to_complpids)
    )
)
print(
    "number of qrels for search before exclusion and after exclusion = {:,}, {:,}".format(
        _num_pairs(qid_to_pids), _num_pairs(train_qid_to_pids)
    )
)

G = nx.MultiDiGraph()
SIM_RELATION = "is_similar_to"
COMPL_RELATION = "is_complementary_to"
REL_RELATION = "is_relevant_to"

# anchor -> similar product
G.add_edges_from(
    (aid, sim_pid, {"type": SIM_RELATION})
    for aid, sim_pids in train_aid_to_simpids.items()
    for sim_pid in sim_pids
)

# anchor -> complementary product
G.add_edges_from(
    (aid, compl_pid, {"type": COMPL_RELATION})
    for aid, compl_pids in train_aid_to_complpids.items()
    for compl_pid in compl_pids
)

# product -> query (note the direction: the product is the source node)
G.add_edges_from(
    (pid, qid, {"type": REL_RELATION})
    for qid, pids in train_qid_to_pids.items()
    for pid in pids
)

# Collect node pairs joined by two parallel edges; more than two is not expected.
multi_edge_pairs = []
for n, nbrs_dict in tqdm(G.adj.items(), total=G.number_of_nodes()):
    for nbr_node, edge_attrs in nbrs_dict.items():
        n_parallel = len(edge_attrs)
        assert n_parallel in (1, 2)
        if n_parallel == 2:
            multi_edge_pairs.append((n, nbr_node))

print(
    "number of edges = {:,}, number of multi-attr edges = {:,}, ({:.3f})".format(
        G.number_of_edges(), len(multi_edge_pairs), len(multi_edge_pairs) / G.number_of_edges()
    )
)

100%|██████████| 216238/216238 [00:00<00:00, 772479.93it/s]
100%|██████████| 86870/86870 [00:00<00:00, 1004893.17it/s]
100%|██████████| 953773/953773 [00:01<00:00, 894635.16it/s]


number of arels for sim_rec before exclusion and after exclusion = 968,778, 944,349
number of arels for compl_rec before exclusion and after exclusion = 329,992, 324,746
number of qrels for search before exclusion and after exclusion = 4,075,996, 4,056,138


100%|██████████| 1375625/1375625 [00:09<00:00, 142695.25it/s]


number of edges = 5,325,233, number of multi-attr edges = 17,830, (0.003)


In [4]:
def create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text, sampler=None,
                   num_entities=2_000_000):
    """Build one ``(hid, pos_tid, neg_tid)`` triple with a randomly drawn negative.

    Args:
        hid: id of the head entity.
        pos_tid: id of the positive tail entity.
        miss_hids: list mutated in place; ``hid`` is appended when ``sampler`` is
            given but offers no usable negative for it (no entry for ``hid``, or
            every candidate equals ``pos_tid``).
        duplicate_pairs: list mutated in place; ``(hid, pos_tid)`` is appended
            when both ids map to the same text in ``eid_to_text``.
        eid_to_text: mapping from entity id to its text.
        sampler: optional dict ``hid -> sequence of candidate tail ids``. When
            given, the negative is drawn from ``sampler[hid]``; otherwise it is
            drawn uniformly from ``range(num_entities)``.
        num_entities: size of the id range used when ``sampler`` is None.
            NOTE(review): the default keeps the previously hard-coded value; the
            first cell's output reports 2,260,878 unique products, so confirm
            that 2,000,000 covers the intended id range.

    Returns:
        The tuple ``(hid, pos_tid, neg_tid)`` with ``neg_tid != pos_tid``, or the
        sentinel ``0`` when the pair is skipped (callers test ``triple != 0``).

    Raises:
        ValueError: if ``sampler`` is None and ``num_entities`` is smaller than 2.
    """
    if sampler is not None:
        assert isinstance(sampler, dict), type(sampler)
        if hid not in sampler:
            miss_hids.append(hid)
            return 0
    if eid_to_text[hid] == eid_to_text[pos_tid]:
        duplicate_pairs.append((hid, pos_tid))
        return 0

    if sampler is not None:
        candidates = sampler[hid]
        # Without this guard the rejection loop below would never terminate when
        # the only candidate is the positive (and an empty list would crash).
        if all(tid == pos_tid for tid in candidates):
            miss_hids.append(hid)
            return 0
    else:
        if num_entities < 2:
            raise ValueError("num_entities must be at least 2, got {}".format(num_entities))
        candidates = range(num_entities)

    # Rejection sampling; the draw itself is unchanged so seeded runs reproduce.
    neg_tid = random.sample(candidates, k=1)[0]
    while neg_tid == pos_tid:
        neg_tid = random.sample(candidates, k=1)[0]

    return (hid, pos_tid, neg_tid)


# Entity id -> text, one tab-separated "<eid>\t<text>" record per line.
with open(os.path.join(user_dir, "all_entities.tsv")) as fin:
    _records = (line.strip().split("\t") for line in fin)
    eid_to_text = {int(eid): text for eid, text in _records}

# BM25 run file: space-separated columns as named below.
run_path = os.path.join(user_dir, "runs/bm25.all.run")
df = pd.read_csv(run_path, sep=" ", names=["hid", "q0", "tid", "rank", "score", "model_name"])

# Heads with at least 10 retrieved candidates keep them; the rest are set aside.
bm25_hid_to_tids = {}
ignore_hids = set()
for hid, tids in df.groupby("hid")["tid"]:
    if len(tids) >= 10:
        bm25_hid_to_tids[int(hid)] = [int(tid) for tid in tids.values]
    else:
        ignore_hids.add(int(hid))

print("number of ignore hids = {}".format(len(ignore_hids)))

number of ignore hids = 6644


In [5]:
# For every head node of G keep at most five randomly chosen tails per relation.
max5_h2sp = {}
max5_h2cp = {}
max5_h2q = {}

for head_node, nbrs_dict in tqdm(G.adj.items(), total=G.number_of_nodes()):
    sim_pids = []
    compl_pids = []
    rel_qids = []
    for tail_node, edge_attrs in nbrs_dict.items():
        assert len(edge_attrs) == 1 or len(edge_attrs) == 2
        for edge_attr in edge_attrs.values():
            rel = edge_attr["type"]
            # Compare with ==: `rel in SIM_RELATION` is a substring test on a
            # string and only worked because no relation name contains another.
            if rel == SIM_RELATION:
                sim_pids.append(tail_node)
            elif rel == COMPL_RELATION:
                compl_pids.append(tail_node)
            elif rel == REL_RELATION:
                rel_qids.append(tail_node)
            else:
                raise AssertionError(rel)
    # Shuffle, then truncate to five (calls kept as-is so seeded runs reproduce).
    if len(sim_pids) != 0:
        max5_h2sp[head_node] = random.sample(sim_pids, k=len(sim_pids))[:5]
    if len(compl_pids) != 0:
        max5_h2cp[head_node] = random.sample(compl_pids, k=len(compl_pids))[:5]
    if len(rel_qids) != 0:
        max5_h2q[head_node] = random.sample(rel_qids, k=len(rel_qids))[:5]

# Shared accumulators: create_triples appends skipped heads / pairs to these.
miss_hids = []
duplicate_pairs = []

h2sp_triples = []
h2cp_triples = []
q2h_triples = []

# Similar items: negatives are drawn at random (no sampler passed).
for hid, tail_ids in max5_h2sp.items():
    for pos_tid in tail_ids:
        triple = create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text)
        if triple != 0:
            h2sp_triples.append(triple)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)

# Complementary items: negatives come from the head's BM25 candidates.
for hid, tail_ids in max5_h2cp.items():
    for pos_tid in tail_ids:
        triple = create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text, sampler=bm25_hid_to_tids)
        if triple != 0:
            h2cp_triples.append(triple)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)

# Relevance edges are stored source -> target, so the dict key is the positive
# tail and each sampled neighbour acts as the head of the triple.
for pos_tid, head_ids in max5_h2q.items():
    for hid in head_ids:
        triple = create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text, sampler=bm25_hid_to_tids)
        if triple != 0:
            q2h_triples.append(triple)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)

100%|██████████| 1375625/1375625 [00:26<00:00, 51996.84it/s] 


miss_hids = 0, duplicate_pairs = 33,036
miss_hids = 0, duplicate_pairs = 33,423
miss_hids = 38,299, duplicate_pairs = 33,423


In [6]:
def _shuffled_prefix(items, limit):
    """Return the first ``limit`` elements of a random permutation of ``items``."""
    return random.sample(items, k=len(items))[:limit]


def _q2p_triples(qid_to_pos_pids):
    """Create a BM25-negative triple for every (qid, pos_pid) pair, dropping skipped pairs."""
    collected = []
    for qid, pos_pids in qid_to_pos_pids.items():
        for pos_pid in pos_pids:
            triple = create_triples(qid, pos_pid, miss_hids, duplicate_pairs, eid_to_text, sampler=bm25_hid_to_tids)
            if triple != 0:
                collected.append(triple)
    return collected


# Per query keep up to five / up to two randomly chosen positives
# (two independent shuffles per query, drawn in this order).
max5_q2p = {}
max2_q2p = {}
for qid, pids in train_qid_to_pids.items():
    max5_q2p[qid] = _shuffled_prefix(pids, 5)
    max2_q2p[qid] = _shuffled_prefix(pids, 2)

max5_q2p_triples = _q2p_triples(max5_q2p)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)

max2_q2p_triples = _q2p_triples(max2_q2p)
print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
print("="*75)
print(len(h2sp_triples), len(h2cp_triples), len(q2h_triples), len(max2_q2p_triples))

miss_hids = 132,829, duplicate_pairs = 33,423
miss_hids = 210,529, duplicate_pairs = 33,423
593676 220656 1072097 1373732


In [11]:
import pickle
import ujson

# Downsample max2_q2p_triples to at most keep_q2p_num triples.
# NOTE: re-running this cell rebinds all_max2_q2p_triples to the list that was
# already downsampled; re-run the cell that builds max2_q2p_triples first.
keep_q2p_num = 1_200_000
all_max2_q2p_triples = max2_q2p_triples
# Cap k at the population size: random.sample raises ValueError when k exceeds it.
max2_q2p_triples = random.sample(
    all_max2_q2p_triples, k=min(keep_q2p_num, len(all_max2_q2p_triples))
)

kgc_train_dir = os.path.join(test_dir, "unified_kgc_train")
# exist_ok avoids the separate existence check.
os.makedirs(kgc_train_dir, exist_ok=True)

# Output file name -> (triples to write, relation label written in the 4th column).
fn_to_tripleNrel = {
    "a2sp.train.tsv": (h2sp_triples, SIM_RELATION),
    "a2cp.train.tsv": (h2cp_triples, COMPL_RELATION),
    "max2_qorient_q2p.train.tsv": (max2_q2p_triples, REL_RELATION),
}

# One tab-separated line per triple: head, positive tail, negative tail, relation.
for fn, (triples, relation) in fn_to_tripleNrel.items():
    with open(os.path.join(kgc_train_dir, fn), "w") as fout:
        for (hid, pos_tid, neg_tid) in triples:
            fout.write(f"{hid}\t{pos_tid}\t{neg_tid}\t{relation}\n")
            

# Create the test directory (exist_ok replaces the separate existence check).
kgc_test_dir = os.path.join(test_dir, "unified_kgc_test_without_ctx")
os.makedirs(kgc_test_dir, exist_ok=True)

# Input JSON-lines test file -> output TSV file.
fin_to_fout = {
    os.path.join(test_dir, "sequential_train_test_hlen_4_bm25/sim_rec_sequential.test.json"): os.path.join(kgc_test_dir, "tid_anchors.sim.test.tsv"),
    os.path.join(test_dir, "sequential_train_test_hlen_4_bm25/compl_rec_sequential.test.json"): os.path.join(kgc_test_dir, "tid_anchors.compl.test.tsv"),
    os.path.join(test_dir, "sequential_train_test_hlen_4_bm25/search_sequential.test.json"): os.path.join(kgc_test_dir, "tid_queries.search.test.tsv")
}
for in_fn, out_fn in fin_to_fout.items():
    with open(in_fn) as fin, open(out_fn, "w") as fout:
        for line in fin:
            example = ujson.loads(line)
            # Only the last id in "query_ids" is used; it is written as its text.
            tid, qid, relation = example["test_id"], example["query_ids"][-1], example["relation"]
            fout.write(f"{tid}\t{eid_to_text[qid]}\t{relation}\n")


In [12]:
# sanity check
# For every entry in kgc_train_dir (skipping any path that contains
# "train_graph.pkl"), show its line count and its first and last three lines.
for path in os.listdir(kgc_train_dir):
    path = os.path.join(kgc_train_dir, path)
    if "train_graph.pkl" in path:
        continue
    # IPython shell escapes; $path is expanded from the Python variable above.
    ! wc -l $path
    ! head -n 3 $path
    ! tail -n 3 $path
    print("="*100)

593676 /home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/zero_shot/unified_kgc_train/a2sp.train.tsv
1048567	331101	1389878	is_similar_to
1048567	1144496	1386487	is_similar_to
1144496	789117	1172136	is_similar_to
1891760	2233677	1194885	is_similar_to
352886	1068443	1463456	is_similar_to
615570	1017607	433181	is_similar_to
220656 /home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/zero_shot/unified_kgc_train/a2cp.train.tsv
331101	1221769	813160	is_complementary_to
331101	1937825	1861909	is_complementary_to
331101	1503775	1144496	is_complementary_to
364149	41477	559183	is_complementary_to
187314	1236902	2085856	is_complementary_to
1864309	2154277	396790	is_complementary_to
1200000 /home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/zero_shot/unified_kgc_train/max2_qorient_q2p.train.tsv
2435046	1531376	2142901	is_relevant_to
2773273	1314329	765662	is_relevant_to
2660965	1413207	417494	is_relevant_to
2470627	8625	1022409	is_relevant_to
2283652	158708

In [15]:
# Spot-check one triple: print each id next to the text stored for it.
hid, pos_tid, neg_tid = 1048567, 331101, 1389878
for _entity_id in (hid, pos_tid, neg_tid):
    print(_entity_id, eid_to_text[_entity_id])

1048567 ReliaBilt 11/16-in x 12-in x 6-ft Western Red Cedar Dog Ear Fence Picket ; Wood Fence Pickets
331101  5/8-in x 5-1/2-in x 6-ft Cedar Dog Ear Fence Picket ; Wood Fence Pickets
1389878 Caroline's Treasures 14-in W x 21-in L x 0.2-in H Fabric Drying Mat ; Dish Racks & Trays


In [21]:
# sanity check
# For every entry in kgc_test_dir, show its line count and its first and last
# five lines.
for path in os.listdir(kgc_test_dir):
    path = os.path.join(kgc_test_dir, path)
    # IPython shell escapes; $path is expanded from the Python variable above.
    ! wc -l $path
    ! head -n 5 $path
    ! tail -n 5 $path
    print("="*100)

3612 /home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/zero_shot/unified_kgc_test_without_ctx/tid_anchors.compl.test.tsv
0	Mr. Heater Buddy heaters 9000-BTU Outdoor Portable Radiant Propane Heater ; Propane Heaters	is_complementary_to
1	Pit Boss Meat Probe 2-Pack Stainless Steel Accessory Kit ; Grilling Tools & Utensils	is_complementary_to
2	Pit Boss Pit Boss Pro Series 4 Series Vertical Smoker ; Pellet Smokers	is_complementary_to
3	Reflectix 2-in x 30-ft Reflective Insulation Tape ; Reflective Insulation Tape	is_complementary_to
4	 R-, 0.75-in x 4-ft x 8-ft Expanded Polystyrene Board Insulation ; Board Insulation	is_complementary_to
3607	Quickie BULLDOZER 9-in Poly Fiber Stiff Deck Brush ; Deck Brushes	is_complementary_to
3608	Seal-Krete Interior/Exterior Concentrated Cleaner and Degreaser (1-Gallon) ; Concrete Preparation	is_complementary_to
3609	Amerimax Contemporary 4-in x 6-in White Half Round Gutter Drop Outlet ; Gutters	is_complementary_to
3610	Hillman 10 x 2-1/2-

In [20]:
# Display the text stored for entity id 2216303.
eid_to_text[2216303]

'Reflectix 2-in x 30-ft Reflective Insulation Tape ; Reflective Insulation Tape'