In [None]:
import os 

from google.cloud import bigquery
import pandas as pd 
import numpy as np

def _load_table(bq_client, table_id):
    """Download every row of one BigQuery table into a pandas DataFrame.

    Args:
        bq_client: the `bigquery.Client` that runs the query.
        table_id: fully qualified table name (`project.dataset.table`).

    Returns:
        The DataFrame returned by `to_dataframe()` for `SELECT *` on the table.
    """
    sql = "SELECT * FROM `{}`;".format(table_id)
    return bq_client.query(sql).to_dataframe()


client = bigquery.Client()
print("Client creating using default project: {}".format(client.project))

# Later cells read the `anchor` and `ivm` columns of the two recommendation
# tables, the `query` and `ivm` columns of the search table, and the
# `product_id` column of the product table.
compl_rec_df = _load_table(client, "gcp-ushi-digital-ds-qa.new_hansi_dataset.comp_rec_ClicksData_2core")
sim_rec_df = _load_table(client, "gcp-ushi-digital-ds-qa.new_hansi_dataset.hansi_rec_ClicksData_5core")
search_df = _load_table(client, "gcp-ushi-digital-ds-qa.new_hansi_dataset.search_ClicksData_1year_5core")
product_df = _load_table(client, "gcp-ushi-digital-ds-qa.hansi_dataset.all_products_info")
print("product_df = {:,}".format(len(product_df)))

# Distinct ids seen in the product catalogue and in the complementary click data.
all_products = set(product_df["product_id"])
anchors = set(compl_rec_df["anchor"])
compl_ivms = set(compl_rec_df["ivm"])
all_compl_ivms = anchors | compl_ivms

print("================================ For anchor_to_compl_ivms: ===================================")
print(
    "number of unique product = {:,}, anchors = {:,}, complementary_compl_ivms = {:,}".format(
        len(all_products), len(anchors), len(compl_ivms)
    )
)
# Every anchor and every complementary ivm must exist in the product catalogue.
assert anchors <= all_products and compl_ivms <= all_products, (
    len(all_products & anchors), len(anchors), len(all_products & compl_ivms), len(compl_ivms)
)

# Same id universe for the similar-item click data, then the union of both.
all_sim_ivms = set(sim_rec_df["anchor"]) | set(sim_rec_df["ivm"])
shared_ivms = all_compl_ivms & all_sim_ivms
all_ivms = all_compl_ivms | all_sim_ivms

print("================================ After updating anchor_to_similar_ivms: ===================================")
print("all_compl_ivms = {:,}, all_sim_ivms = {:,}".format(len(all_compl_ivms), len(all_sim_ivms)))
print("sim_compl_intersect = {:,} ({:.3f})".format(len(shared_ivms), len(shared_ivms) / len(all_compl_ivms)))
print("all_ivms = {:,}".format(len(all_ivms)))

# The combined id set must also be fully covered by the catalogue.
assert all_ivms <= all_products, (len(all_products & all_ivms), len(all_ivms))

# Search clicks grouped both ways: query -> list of ivms, ivm -> list of queries.
search_by_query = search_df.groupby("query")
search_by_ivm = search_df.groupby("ivm")
query_to_ivms = search_by_query["ivm"].apply(list)
ivm_to_queries = search_by_ivm["query"].apply(list)

# Number of queries attached to each ivm.
query_lengths = np.array(list(map(len, ivm_to_queries.values)))
all_queries = set(search_df["query"])
print("all queries = {:,}".format(len(all_queries)))
print(
    "total ivms (queries) = {:,}, length >=3 = {:,}, length >= 5 = {:,}".format(
        len(query_lengths), (query_lengths >= 3).sum(), (query_lengths >= 5).sum()
    )
)

# Complementary clicks grouped per anchor, with the same length statistics.
anchor_to_compl_ivms = compl_rec_df.groupby("anchor")["ivm"].apply(list)
compl_ivms_length = np.array(list(map(len, anchor_to_compl_ivms.values)))
print("================================ For anchor_to_compl_ivms: ===================================")
print(
    "total_compl_ivms = {:,}, length >=3 = {:,}, length >= 5 = {:,}".format(
        len(compl_ivms_length), (compl_ivms_length >= 3).sum(), (compl_ivms_length >= 5).sum()
    )
)

# Similar-item clicks grouped per anchor.
anchor_to_sim_ivms = sim_rec_df.groupby("anchor")["ivm"].apply(list)


In [24]:
import pickle

in_dir = "/home/jupyter/jointly_rec_and_search/datasets/unified_kgc/"


def _unpickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE(review): pickle.load can execute arbitrary code; these files are
    # presumably produced locally by an earlier step -- confirm before reuse.
    with open(path, "rb") as fin:
        return pickle.load(fin)


ivm_to_pid = _unpickle(os.path.join(in_dir, "ivm_to_pid.pkl"))
query_to_qid = _unpickle(os.path.join(in_dir, "query_to_qid.pkl"))

# Re-key the search relation from raw ivm / query values to integer ids.
pid_to_qids = {}
for ivm, queries in ivm_to_queries.items():
    pid = ivm_to_pid[ivm]
    pid_to_qids[pid] = [query_to_qid[q] for q in queries]


def _anchors_to_pids(anchor_to_ivms):
    """Translate an anchor -> [ivm, ...] mapping into aid -> [pid, ...]."""
    translated = {}
    for anchor, products in anchor_to_ivms.items():
        aid = ivm_to_pid[anchor]
        translated[aid] = [ivm_to_pid[prod] for prod in products]
    return translated


aid_to_complpids = _anchors_to_pids(anchor_to_compl_ivms)
aid_to_simpids = _anchors_to_pids(anchor_to_sim_ivms)

In [None]:
# start create graph
import random 
from tqdm import tqdm
import networkx as nx

random.seed(4680)

# random.sample() requires a sequence; passing a dict view raises TypeError on
# Python 3.11+, so the anchor ids are materialised as a list first.
all_aids = list(aid_to_simpids.keys())

# Hold out 20% of the anchors, then split the hold-out evenly into val / test.
val_test_aids = random.sample(all_aids, int(0.2 * len(all_aids)))
n_val = int(0.5 * len(val_test_aids))
val_aids = val_test_aids[:n_val]
test_aids = val_test_aids[n_val:]

# Set copies make each membership test O(1); scanning the lists made the loop
# below quadratic in the number of anchors.
val_aid_set = set(val_aids)
test_aid_set = set(test_aids)

train_aid_to_simpids, val_aid_to_simpids, test_aid_to_simpids = {}, {}, {}
for aid, simpids in tqdm(aid_to_simpids.items(), total=len(aid_to_simpids)):
    if aid in val_aid_set:
        val_aid_to_simpids[aid] = simpids
    elif aid in test_aid_set:
        test_aid_to_simpids[aid] = simpids
    else:
        train_aid_to_simpids[aid] = simpids

print("number of aid_to_simpids  train = {:,}, val = {:,}, test = {:,}".format(len(train_aid_to_simpids), 
                                                                              len(val_aid_to_simpids), len(test_aid_to_simpids)))
# The splits must be pairwise disjoint (a three-way intersection would miss an
# id shared by only two of them) and together cover every anchor.
assert not (train_aid_to_simpids.keys() & val_aid_to_simpids.keys())
assert not (train_aid_to_simpids.keys() & test_aid_to_simpids.keys())
assert not (val_aid_to_simpids.keys() & test_aid_to_simpids.keys())
assert len(train_aid_to_simpids) + len(val_aid_to_simpids) + len(test_aid_to_simpids) == len(aid_to_simpids)

G = nx.MultiDiGraph()
SIM_RELATION = "is_similar_to"
COMPL_RELATION = "is_complementary_to"
REL_RELATION = "is_relevant_to"

# Edges are inserted relation by relation in this fixed order. Similarity edges
# come from the training anchors only; complementary and relevance edges use
# the full mappings.
edge_sources = (
    (train_aid_to_simpids, SIM_RELATION),
    (aid_to_complpids, COMPL_RELATION),
    (pid_to_qids, REL_RELATION),
)
for head_to_tails, relation in edge_sources:
    for head, tails in head_to_tails.items():
        # One directed edge per (head, tail), tagged with its relation name.
        G.add_edges_from((head, tail, {"type": relation}) for tail in tails)

# Collect the (source, target) pairs that are linked by two parallel edges.
multi_edge_pairs = []
for src, nbrs in G.adj.items():
    for dst, parallel_edges in nbrs.items():
        n_parallel = len(parallel_edges)
        # A node pair is expected to carry one or two edges, never more.
        assert n_parallel in (1, 2)
        if n_parallel > 1:
            multi_edge_pairs.append((src, dst))

total_edges = G.number_of_edges()
print(
    "number of edges = {:,}, number of multi-attr edges = {:,}, ({:.3f})".format(
        total_edges, len(multi_edge_pairs), len(multi_edge_pairs) / total_edges
    )
)

In [None]:
def create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text, sampler=None,
                   num_random_candidates=2_000_000):
    """Build one (head, positive tail, negative tail) training triple.

    Args:
        hid: id of the head entity.
        pos_tid: id of the positive tail entity.
        miss_hids: list mutated in place; `hid` is appended when `sampler` is
            given but offers no usable negative for `hid`.
        duplicate_pairs: list mutated in place; `(hid, pos_tid)` is appended
            when both ids map to the same text.
        eid_to_text: mapping from entity id to its text.
        sampler: optional dict mapping a head id to candidate negative tail
            ids. When None, the negative is drawn from
            `range(num_random_candidates)` instead.
        num_random_candidates: size of the id range used for random negatives
            when no sampler is given.

    Returns:
        `(hid, pos_tid, neg_tid)` with `neg_tid != pos_tid`, or `0` when the
        pair is skipped (callers compare the result against `0`).
    """
    if sampler is not None:
        assert isinstance(sampler, dict), type(sampler)
        if hid not in sampler:
            miss_hids.append(hid)
            return 0
    if eid_to_text[hid] == eid_to_text[pos_tid]:
        duplicate_pairs.append((hid, pos_tid))
        return 0

    if sampler is not None:
        candidates = sampler[hid]
        # Without any candidate that differs from the positive, the rejection
        # loop below would never end (or sample() would raise on an empty
        # list), so treat the head as having no usable negatives.
        if all(tid == pos_tid for tid in candidates):
            miss_hids.append(hid)
            return 0
    else:
        candidates = range(num_random_candidates)

    # Rejection sampling: redraw until the negative differs from the positive.
    neg_tid = random.sample(candidates, k=1)[0]
    while neg_tid == pos_tid:
        neg_tid = random.sample(candidates, k=1)[0]

    return (hid, pos_tid, neg_tid)


# Entity id -> text, read from a two-column tab-separated file.
eid_to_text = {}
with open(os.path.join(in_dir, "all_entities.tsv")) as fin:
    for line in fin:
        if not line.strip():
            # Skip blank lines (e.g. a trailing newline at end of file).
            continue
        # Split on the first tab only, so an empty text or a text containing
        # tabs no longer breaks the two-value unpacking.
        eid, _, text = line.partition("\t")
        eid_to_text[int(eid)] = text.rstrip()
        
# BM25 run file: one space-separated row per (head, candidate tail).
run_path = os.path.join(in_dir, "runs/bm25.all.run")
df = pd.read_csv(run_path, sep=" ", names=["hid", "q0", "tid", "rank", "score", "model_name"])

bm25_hid_to_tids = {}
ignore_hids = set()
for raw_hid, tid_column in df.groupby("hid")["tid"]:
    head = int(raw_hid)
    if len(tid_column) >= 10:
        bm25_hid_to_tids[head] = [int(tid) for tid in tid_column.values]
    else:
        # Heads with fewer than 10 candidates are left out of the sampler.
        ignore_hids.add(head)

print("number of ignore hids = {}".format(len(ignore_hids)))

In [None]:
def _sample_tails_by_relation(nbrs_dict, max_per_relation=5):
    """Group one node's outgoing edges by relation and keep a random subset.

    Args:
        nbrs_dict: `G.adj[node]`, i.e. tail node -> {edge key -> attribute dict}.
        max_per_relation: maximum number of tails kept for each relation.

    Returns:
        Dict relation -> list of at most `max_per_relation` tails in shuffled
        order. Relations with no outgoing edge are omitted.
    """
    tails_by_relation = {SIM_RELATION: [], COMPL_RELATION: [], REL_RELATION: []}
    for tail_node, edge_attrs in nbrs_dict.items():
        assert len(edge_attrs) == 1 or len(edge_attrs) == 2
        for edge_attr in edge_attrs.values():
            rel = edge_attr["type"]
            # Exact key lookup. The previous `rel in SIM_RELATION` was a
            # substring test on a string, which only worked because no
            # relation name happens to be contained in another.
            assert rel in tails_by_relation, rel
            tails_by_relation[rel].append(tail_node)
    # Shuffle each group, then truncate; groups are drawn in the order
    # similar, complementary, relevant.
    return {
        rel: random.sample(tails, k=len(tails))[:max_per_relation]
        for rel, tails in tails_by_relation.items()
        if tails
    }


# One hop: up to five sampled neighbours per relation for every graph node.
max5_h2sp = {}
max5_h2cp = {}
max5_h2q = {}
h_targets = {SIM_RELATION: max5_h2sp, COMPL_RELATION: max5_h2cp, REL_RELATION: max5_h2q}
for head_node, nbrs_dict in tqdm(G.adj.items(), total=G.number_of_nodes()):
    for rel, tails in _sample_tails_by_relation(nbrs_dict).items():
        h_targets[rel][head_node] = tails

# Two hops: repeat the sampling for each similar product kept above. A node
# reached from several heads is re-sampled each time and the last draw is kept.
max5_s2sp = {}
max5_s2cp = {}
max5_s2q = {}
s_targets = {SIM_RELATION: max5_s2sp, COMPL_RELATION: max5_s2cp, REL_RELATION: max5_s2q}
for sim_nodes in tqdm(max5_h2sp.values(), total=len(max5_h2sp)):
    for sim_node in sim_nodes:
        for rel, tails in _sample_tails_by_relation(G.adj[sim_node]).items():
            s_targets[rel][sim_node] = tails
        
# Shared across all calls below, so the printed counts are cumulative.
miss_hids = []
duplicate_pairs = []


def _collect_triples(node_to_nbrs, sampler=None, node_is_tail=False):
    """Turn a node -> neighbours mapping into (head, positive, negative) triples.

    Args:
        node_to_nbrs: dict mapping a node id to its sampled neighbour ids.
        sampler: forwarded to `create_triples`; None means random negatives.
        node_is_tail: when True the dict key is the positive tail and each
            neighbour is the head (used for the query mappings); otherwise
            the key is the head and each neighbour is the positive tail.

    Returns:
        List of triples; pairs for which `create_triples` returned 0 are dropped.
    """
    collected = []
    for node, nbrs in node_to_nbrs.items():
        for nbr in nbrs:
            hid, pos_tid = (nbr, node) if node_is_tail else (node, nbr)
            triple = create_triples(hid, pos_tid, miss_hids, duplicate_pairs, eid_to_text, sampler=sampler)
            if triple != 0:
                collected.append(triple)
    print("miss_hids = {:,}, duplicate_pairs = {:,}".format(len(miss_hids), len(duplicate_pairs)))
    print("="*75)
    return collected


# Triples for the one-hop neighbourhoods.
h2sp_triples = _collect_triples(max5_h2sp)
h2cp_triples = _collect_triples(max5_h2cp, sampler=bm25_hid_to_tids)
q2h_triples = _collect_triples(max5_h2q, sampler=bm25_hid_to_tids, node_is_tail=True)

# Triples for the two-hop (similar product) neighbourhoods.
s2sp_triples = _collect_triples(max5_s2sp)
s2cp_triples = _collect_triples(max5_s2cp, sampler=bm25_hid_to_tids)
q2s_triples = _collect_triples(max5_s2q, sampler=bm25_hid_to_tids, node_is_tail=True)

In [65]:
# Training triples go to <in_dir>/sim_train/, one TSV file per triple set.
out_dir = os.path.join(in_dir, "sim_train/")
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

# Output file name -> (triples, relation name written in the last column).
fn_to_tripleNrel = {
    "a2sp.train.tsv": (h2sp_triples, SIM_RELATION),
    "a2cp.train.tsv": (h2cp_triples, COMPL_RELATION),
    "q2a.train.tsv": (q2h_triples, REL_RELATION),
    "s2sp.train.tsv": (s2sp_triples, SIM_RELATION),
    "s2cp.train.tsv": (s2cp_triples, COMPL_RELATION),
    "q2s.train.tsv": (q2s_triples, REL_RELATION)
}

for fn, (triples, relation) in fn_to_tripleNrel.items():
    target_path = os.path.join(out_dir, fn)
    with open(target_path, "w") as fout:
        # One row per triple: head, positive tail, negative tail, relation.
        fout.writelines(
            f"{hid}\t{pos_tid}\t{neg_tid}\t{relation}\n"
            for hid, pos_tid, neg_tid in triples
        )
            
# Anchor lists and relevance judgements go to <in_dir>/sim_test/.
out_dir = os.path.join(in_dir, "sim_test/")
if not os.path.exists(out_dir):
    os.mkdir(out_dir)


def _anchor_pairs(aid_to_pids):
    """Flatten an anchor -> [pid, ...] mapping into a list of (aid, pid) pairs."""
    pairs = []
    for aid, simpids in aid_to_pids.items():
        pairs.extend((aid, pid) for pid in simpids)
    return pairs


# Anchor files: id, text and the similarity relation name, one anchor per row.
fn_to_aids = {
    "anchors.train.tsv": list(train_aid_to_simpids),
    "anchors.val.tsv": list(val_aid_to_simpids),
    "anchors.test.tsv": list(test_aid_to_simpids)
}
for fn, aids in fn_to_aids.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        fout.writelines(f"{aid}\t{eid_to_text[aid]}\t{SIM_RELATION}\n" for aid in aids)

# Judgement files: one "<aid> Q0 <pid> 1" row per (anchor, similar product).
fn_to_arels = {
    "arels.train.tsv": _anchor_pairs(train_aid_to_simpids),
    "arels.val.tsv": _anchor_pairs(val_aid_to_simpids),
    "arels.test.tsv": _anchor_pairs(test_aid_to_simpids),
}
for fn, arels in fn_to_arels.items():
    with open(os.path.join(out_dir, fn), "w") as fout:
        fout.writelines(f"{aid}\tQ0\t{pid}\t{1}\n" for aid, pid in arels)

In [None]:
# sanity check
# For every entry in sim_train/, print its line count followed by its first and
# last three lines so the row format can be inspected by eye.
out_dir = os.path.join(in_dir, "sim_train/")
for path in os.listdir(out_dir):
    path = os.path.join(out_dir, path)
    # Lines starting with `!` are IPython shell escapes; `$path` expands the
    # Python variable `path` into the command.
    ! wc -l $path
    ! head -n 3 $path
    ! tail -n 3 $path
    print("="*100)



In [80]:
# Spot-check one (head, positive, negative) triple; the ids are hard-coded here.
hid, pos_tid, neg_tid = (3061845,2233677,2054671)

# Each grep prints the all_entities.tsv line that starts with the given id
# followed by a tab; `{...}` is expanded by IPython before the shell runs.
! grep -P "^{hid}\t" "/home/jupyter/jointly_rec_and_search/datasets/unified_kgc/all_entities.tsv"
! grep -P "^{pos_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/unified_kgc/all_entities.tsv"
! grep -P "^{neg_tid}\t" "/home/jupyter/jointly_rec_and_search/datasets/unified_kgc/all_entities.tsv"


3061845	speakman shower fixtures
2233677	Speakman Chelsea Polished Chrome Dual Shower Head 2.5-GPM (9.5-LPM) ; Shower Heads
2054671	Speakman Brushed Nickel Shower Hose ; Bathroom & Shower Faucet Accessories


In [45]:
# For every entry in sim_test/, print its line count followed by its first and
# last three lines.
out_dir = os.path.join(in_dir, "sim_test/")
for path in os.listdir(out_dir):
    path = os.path.join(out_dir, path)
    # IPython shell escapes; `$path` expands the Python variable `path`.
    ! wc -l $path
    ! head -n 3 $path
    ! tail -n 3 $path
    print("="*100)

[(1048567, 331101, 1150236),
 (1048567, 1144496, 1481931),
 (1144496, 789117, 792800),
 (1144496, 1777815, 1648864),
 (1363017, 2098989, 1720379),
 (1363017, 524725, 605724),
 (1363017, 731516, 963135),
 (1363017, 745636, 564390),
 (1363017, 935334, 1226274),
 (731516, 2098989, 1257665)]

In [83]:
# Number of head nodes that kept at least one relevant query, and at least one
# similar product, after the per-relation sampling.
len(max5_h2q), len(max5_h2sp)

(360744, 172991)