In [12]:
import os 

import pandas as pd
import numpy as np 
from tqdm import tqdm

in_dir = "/home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/"

datas = []
fns = [
    os.path.join(in_dir, "train_sim_recs.csv"),
    os.path.join(in_dir, "test_sim_recs.csv"),
    os.path.join(in_dir, "train_compl_recs.csv"),
    os.path.join(in_dir, "test_compl_recs.csv"),
    os.path.join(in_dir, "train_searchs.csv"),
    os.path.join(in_dir, "test_searchs.csv"),
]

for fn in fns:
    datas.append(pd.read_csv(fn, index_col=0))
    
train_sim_data, test_sim_data, train_compl_data, test_compl_data, train_search_data, test_search_data = datas
sim_data = pd.concat([train_sim_data, test_sim_data])
compl_data = pd.concat([train_compl_data, test_compl_data])
search_data = pd.concat([train_search_data, test_search_data])
assert len(sim_data) == len(train_sim_data) + len(test_sim_data) 
assert len(compl_data) == len(train_compl_data) + len(test_compl_data) 
assert len(search_data) == len(train_search_data) + len(test_search_data)

  mask |= (ar1 == a)


In [47]:
import random
from collections import defaultdict
import pickle as pkl
import pandas as pd
random.seed(4680)

out_dir = os.path.join(in_dir, "zero_shot")
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

max_user_num = 10_000
selected_sim_users = random.sample(list(sim_data.uid), k=min(max_user_num, int(len(sim_data.uid.unique())*0.1)))
selected_compl_users = random.sample(list(compl_data.uid), k=min(max_user_num, int(len(compl_data.uid.unique())*0.1)))
selected_search_users = random.sample(list(search_data.uid), k=min(max_user_num, int(len(search_data.uid.unique())*0.1)))

print("number of selected sim, compl, search users {:,}, {:,}, {:,}".format(len(selected_sim_users), len(selected_compl_users), 
                                                                            len(selected_search_users)))

test_sim_data = sim_data[np.in1d(sim_data.uid, selected_sim_users)]
test_compl_data = compl_data[np.in1d(compl_data.uid, selected_compl_users)]
test_search_data = search_data[np.in1d(search_data.uid, selected_search_users)]

train_sim_data = sim_data[~np.in1d(sim_data.uid, selected_sim_users)]
train_compl_data = compl_data[~np.in1d(compl_data.uid, selected_compl_users)]
train_search_data = search_data[~np.in1d(search_data.uid, selected_search_users)]

assert len(test_sim_data) + len(train_sim_data) == len(sim_data) and len(test_compl_data) + len(train_compl_data) == len(compl_data)
assert len(test_search_data) + len(train_search_data) == len(search_data)

train_aid_to_simpids, train_aid_to_complpids, train_qid_to_pids = defaultdict(set), defaultdict(set), defaultdict(set)
test_aid_to_simpids, test_aid_to_complpids, test_qid_to_pids = defaultdict(set), defaultdict(set), defaultdict(set)
for aid, simpids in zip(train_sim_data.aid, train_sim_data.sim_pids):
    train_aid_to_simpids[aid].update(eval(simpids))
for aid, simpids in zip(test_sim_data.aid, test_sim_data.sim_pids):
    test_aid_to_simpids[aid].update(eval(simpids))
for aid, complpids in zip(train_compl_data.aid, train_compl_data.compl_pids):
    train_aid_to_complpids[aid].update(eval(complpids))
for aid, complpids in zip(test_compl_data.aid, test_compl_data.compl_pids):
    test_aid_to_complpids[aid].update(eval(complpids))
for qid, relpids in zip(train_search_data.qid, train_search_data.rel_pids):
    train_qid_to_pids[qid].update(eval(relpids))
for qid, relpids in tqdm(zip(test_search_data.qid, test_search_data.rel_pids), total=len(test_search_data)):
    test_qid_to_pids[qid].update(eval(relpids))
    

print("train sim_arels, compl_arels, search_qrels = {:,}, {:,}, {:,}".format(
    sum([len(x) for x in train_aid_to_simpids.values()]), sum([len(x) for x in train_aid_to_complpids.values()]), 
    sum([len(x) for x in train_qid_to_pids.values()])
))
print("test sim_arels, compl_arels, search_qrels = {:,}, {:,}, {:,}".format(
    sum([len(x) for x in test_aid_to_simpids.values()]), sum([len(x) for x in test_aid_to_complpids.values()]), 
    sum([len(x) for x in test_qid_to_pids.values()])
))

exclude_aid_to_simpids, exclude_aid_to_complpids, exclude_qid_to_pids = {}, {}, {}
for aid, simpids in test_aid_to_simpids.items():
    if aid in train_aid_to_simpids:
        exclude_pids = simpids.difference(train_aid_to_simpids[aid])
    else:
        exclude_pids = simpids
    exclude_aid_to_simpids[aid] = exclude_pids
for aid, complpids in test_aid_to_complpids.items():
    if aid in train_aid_to_complpids:
        exclude_pids = complpids.difference(train_aid_to_complpids[aid])
    else:
        exclude_pids = complpids
    exclude_aid_to_complpids[aid] = exclude_pids
for qid, pids in test_qid_to_pids.items():
    if qid in train_qid_to_pids:
        exclude_pids = pids.difference(train_qid_to_pids[qid])
    else:
        exclude_pids = pids
    exclude_qid_to_pids[qid] = exclude_pids
    
print("after difference, test sim_arels, compl_arels, search_qrels = {:,}, {:,}, {:,}".format(
    sum([len(x) for x in exclude_aid_to_simpids.values()]), sum([len(x) for x in exclude_aid_to_complpids.values()]), 
    sum([len(x) for x in exclude_qid_to_pids.values()])
))

fn_to_data = {
    os.path.join(out_dir, "exclude_aid_to_simpids.pkl"): exclude_aid_to_simpids,
    os.path.join(out_dir, "exclude_aid_to_complpids.pkl"): exclude_aid_to_complpids,
    os.path.join(out_dir, "exclude_qid_to_relpids.pkl"): exclude_qid_to_pids,
}
for fn, data in fn_to_data.items():
    with open(fn, "wb") as fout:
        pkl.dump(data, fout)

number of selected sim, compl, search users 8,166, 1,262, 10,000


100%|██████████| 330924/330924 [00:01<00:00, 177863.04it/s]


train sim_arels, compl_arels, search_qrels = 321,960, 58,493, 2,463,484
test sim_arels, compl_arels, search_qrels = 71,517, 10,521, 203,891
after difference, test sim_arels, compl_arels, search_qrels = 24,429, 5,246, 19,858


In [52]:
def get_sequential_examples(data, prefix, sample_target_value_ids=True):
    seq_examples = []
    for uid, group in tqdm(data.groupby("uid")):
        if "search_sequential" in prefix:
            qids = list(group.qid)
            group_rel_pids = group.rel_pids 
        elif "sim_rec_sequential" in prefix:
            qids = list(group.aid)
            group_rel_pids = group.sim_pids
        elif "compl_rec_sequential" in prefix:
            qids = list(group.aid)
            group_rel_pids = group.compl_pids
        else:
            raise ValueError(f"{prefix} not valid.")
            
        rel_pids = []
  sample_relpids xs in group_rel_pids:
            rel_pids.append(random.sample(eval(xs), k=1)[0]) # only sample 1 relpid 
        assert len(qids) == len(rel_pids) == len(group)
        
        if not sample_target_value_ids:
            target_value_ids = [eval(xs) for xs in group_rel_pids][1:]
        else:
            

        uid = int(uid)
        qids = [int(x) for x in qids]
        rel_pids = [int(x) for x in rel_pids]

        query_ids = qids[1:]
        context_key_ids = qids[:-1]
        context_value_ids = rel_pids[:-1]
        if sample_target_value_ids:
            target_value_ids = rel_pids[1:]
        else:
            
        assert len(query_ids) == len(context_key_ids) == len(context_value_ids) == len(target_value_ids)

        example = {"uid": uid, "query_ids": query_ids, "context_key_ids": context_key_ids, "context_value_ids": context_value_ids,
                    "target_value_ids": target_value_ids}
        
        seq_examples.append(example)
        
    return seq_examples

train_search_examples = get_sequential_examples(train_search_data, "search_sequential")
test_search_examples = get_sequential_examples(test_search_data, "search_sequential")
train_sim_examples = get_sequential_examples(train_sim_data, "sim_rec_sequential")
test_sim_examples = get_sequential_examples(test_sim_data, "sim_rec_sequential")
train_compl_examples = get_sequential_examples(train_compl_data, "compl_rec_sequential")
test_compl_examples = get_sequential_examples(test_compl_data, "compl_rec_sequential")


100%|██████████| 805946/805946 [04:14<00:00, 3162.71it/s]
100%|██████████| 9886/9886 [00:04<00:00, 2308.78it/s]
100%|██████████| 74036/74036 [00:19<00:00, 3831.65it/s]
100%|██████████| 7628/7628 [00:02<00:00, 3194.34it/s]
100%|██████████| 11424/11424 [00:04<00:00, 2552.05it/s]
100%|██████████| 1204/1204 [00:00<00:00, 4767.93it/s]


In [54]:
def create_neg_value_ids(query_ids, pos_value_ids, miss_qids, sampler=None):
    assert type(sampler) == dict
    assert len(query_ids) == len(pos_value_ids)
    neg_value_ids = []
    for qid, pos_vid in zip(query_ids, pos_value_ids):
        if qid not in sampler:
            miss_qids.add(qid)
            neg_vid = random.sample(range(2_000_000), k=1)[0]
            while neg_vid == pos_vid:
                neg_vid = random.sample(range(2_000_000), k=1)[0]
            neg_value_ids.append(neg_vid)
        else:
            neg_vid = random.sample(sampler[qid], k=1)[0]
            while neg_vid == pos_vid:
                neg_vid = random.sample(range(2_000_000), k=1)[0]
            neg_value_ids.append(neg_vid)
    
    assert len(neg_value_ids) == len(pos_value_ids)
    
    return neg_value_ids

run_path = os.path.join(in_dir, "runs/bm25.all.run")
df = pd.read_csv(run_path, sep=" ", names=["hid", "q0", "tid", "rank", "score", "model_name"])
bm25_hid_to_tids = {}
ignore_hids = set()
for hid, group in df.groupby("hid"):
    cand_tids = list(group.tid.values)
    if len(cand_tids) < 10:
        ignore_hids.add(int(hid))
    else:
        bm25_hid_to_tids[int(hid)] = [int(x) for x in cand_tids]
        
print("number of ignore hids = {}".format(len(ignore_hids)))

number of ignore hids = 6644


In [35]:
# sanity check
for fn in os.listdir(selected_dir):
    fn = os.path.join(selected_dir, fn)
    if fn.endswith(".pkl"):
        continue
    ! wc -l $fn
    ! head -n 2 $fn
    print(75*"=")

7274 /home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/selected_test_user/queries.search.test.tsv
2532088	plumb bob
2860183	toilet leveling kit
6720 /home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/selected_test_user/anchors.compl.test.tsv
1918558	Grip-Rite #9 x 3-in Yellow Zinc Interior Wood Screws (722-Count) ; Wood Screws
278428	Freedom Newport 3-ft H x 8-ft W White Vinyl Gothic Fence Panel ; Vinyl Fencing
8273 /home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/selected_test_user/anchors.sim.test.tsv
992738	DURALENS 24-in x 48-in 7.85-sq ft Prism Ceiling Light Panels ; Ceiling Light Panels
1378791	Satori Bianco Perla 12-in x 12-in Polished Natural Stone Marble Hexagon Stone Look Floor and Wall Tile ; Tile
15708 /home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/selected_test_user/aid_to_complpid.test.tsv
1918558	372243
278428	2174624
11412 /home/jupyter/unity_jointly_rec_and_search/datasets/unified_user/selected_test_user/qi

In [39]:
eid_to_text[278428], eid_to_text[2174624]

('Freedom Newport 3-ft H x 8-ft W White Vinyl Gothic Fence Panel ; Vinyl Fencing',
 'Freedom 6-ft H x 3-in W White Vinyl Fence Gate Kit ; Vinyl Fencing')

In [23]:
test_sim_data

Unnamed: 0,uid,aid,sim_pids,date_time,visit_id
153,625467,893908,[2038631],2021-07-05 14:50:23,2580158140661277636|8065034530337164545|1
154,625467,893908,[2038631],2021-07-05 14:53:51,2580158140661277636|8065034530337164545|1
155,625467,893908,[2038631],2021-07-05 14:54:57,2580158140661277636|8065034530337164545|1
156,625467,893908,[2038631],2021-07-05 14:55:59,2580158140661277636|8065034530337164545|1
157,625467,893908,[2038631],2021-07-05 14:57:08,2580158140661277636|8065034530337164545|1
...,...,...,...,...,...
81639,177799,103290,"[1660177, 145865, 465130, 1371823, 849948, 126...",2022-03-04 09:38:33,7591656150586395866|5093192628615942583|2
81644,344035,384579,"[186736, 1104942, 1254408, 657600, 756378, 171...",2022-05-09 14:56:18,2188817688176677597|755576825108835916|1
81652,629779,1904166,"[2098042, 381554, 1266943, 1084581, 731124, 15...",2022-06-01 22:51:35,5167584331604218685|4651610697708747563|1
81657,188583,2054548,[1166185],2021-10-25 03:34:07,5523687833381507832|3471855793616288331|150
