In [1]:
import os
import random
import json
import pickle
from copy import deepcopy
from tqdm import tqdm
from collections import defaultdict

import numpy as np
from transformers import BertTokenizerFast

In [2]:
random.seed(0)

In [3]:
dataset = 'amazon' 
sub_dataset='sports'

# Generate Pretraining Data

In [4]:
# read raw data
with open(f'amazon/{sub_dataset}/product.json') as f:
    raw_data = {}
    readin = f.readlines()
    for line in tqdm(readin):
        #data.append(json.loads(line))
        #data.append(eval(line.strip()))
        tmp = eval(line.strip())
        raw_data[tmp['asin']] = tmp
#random.shuffle(data)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 532197/532197 [01:11<00:00, 7454.22it/s]


In [5]:
# statistics on label names
label_name_stat = defaultdict(int)

for did in tqdm(raw_data):
    sample = raw_data[did]
    c_list = list(set(sum(sample['categories'], [])))
    for c in c_list:
        label_name_stat[c] += 1

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 532197/532197 [00:01<00:00, 325071.44it/s]


In [6]:
# read label name dict

label_name_dict = {}
label_name_set = set()
label_name2id_dict = {}

for n in label_name_stat:
    if label_name_stat[n] > int(0.5 * len(raw_data)):
        continue

    label_name_dict[len(label_name_dict)] = n
    label_name_set.add(n)
    label_name2id_dict[n] = len(label_name_dict) - 1

print(f'Num of unique labels:{len(label_name_set)}')

Num of unique labels:3034


In [7]:
# filter item with no text

data = {}
for idd in tqdm(raw_data):
    if 'title' in raw_data[idd]:
        data[idd] = raw_data[idd]
print(len(data))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 532197/532197 [00:00<00:00, 964799.58it/s]

529901





In [8]:
# filter related

idd_set = set(list(data.keys()))

for idd in tqdm(data):
    if 'related' not in data[idd]:
        continue
        
    if 'also_bought' in data[idd]['related']:
        data[idd]['related']['also_bought'] = list(set(data[idd]['related']['also_bought']) & idd_set)
        
    if 'also_viewed' in data[idd]['related']:
        data[idd]['related']['also_viewed'] = list(set(data[idd]['related']['also_viewed']) & idd_set)
        
    if 'bought_together' in data[idd]['related']:
        data[idd]['related']['bought_together'] = list(set(data[idd]['related']['bought_together']) & idd_set)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 529901/529901 [00:04<00:00, 117485.18it/s]


In [9]:
# text processing function
def text_process(text):
    p_text = ' '.join(text.split('\r\n'))
    p_text = ' '.join(p_text.split('\n\r'))
    p_text = ' '.join(p_text.split('\n'))
    p_text = ' '.join(p_text.split('\t'))
    p_text = ' '.join(p_text.split('\rm'))
    p_text = ' '.join(p_text.split('\r'))
    p_text = ''.join(p_text.split('$'))
    p_text = ''.join(p_text.split('*'))

    return p_text

In [10]:
# average edge

also_bought_cnt = 0
also_viewed_cnt = 0
bought_together_cnt = 0

also_bought_item = {}
also_viewed_item = {}
bought_together_item = {}


for idd in tqdm(data):
    if 'related' not in data[idd] or 'title' not in data[idd]:
        continue
        
    if 'also_bought' in data[idd]['related']:        
        also_bought_cnt += len(data[idd]['related']['also_bought'])
        also_bought_item[idd] = data[idd]
        
    if 'also_viewed' in data[idd]['related']:
        also_viewed_cnt += len(data[idd]['related']['also_viewed'])
        also_viewed_item[idd] = data[idd]
        
    if 'bought_together' in data[idd]['related']:
        bought_together_cnt += len(data[idd]['related']['bought_together'])
        bought_together_item[idd] = data[idd]

print(f'average also bought:{also_bought_cnt/len(data)}, averagte also viewed:{also_viewed_cnt/len(data)}, average bought together:{bought_together_cnt/len(data)}.')
print(f'also bought items:{len(also_bought_item)}, also viewed items:{len(also_viewed_item)}, bought together items:{len(bought_together_item)}')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1502696/1502696 [00:03<00:00, 476815.25it/s]

average also bought:4.177224135819887, averagte also viewed:6.8697906961887165, average bought together:0.3685841980014587.
also bought items:661544, also viewed items:959912, bought together items:572456





In [11]:
## split train/val/test as 8:1:1

random.seed(0)

train_pairs = []
val_pairs = []
test_pairs = []
train_pair_set = set()
item_id2idx = {}
train_neighbor = defaultdict(list)

for iid in tqdm(also_viewed_item):
    if iid not in item_id2idx:
        item_id2idx[iid] = len(item_id2idx)
    
    also_viewed = also_viewed_item[iid]['related']['also_viewed']
    random.shuffle(also_viewed)
    
    for i in range(int(len(also_viewed)*0.8)):
        train_pairs.append((iid,also_viewed[i]))
        train_pair_set.add((iid,also_viewed[i]))
        train_pair_set.add((also_viewed[i],iid))
        
        # add to item_id2idx
        if also_viewed[i] not in item_id2idx:
            item_id2idx[also_viewed[i]] = len(item_id2idx)

        # add to train_user_neighbor/train_item_neighbor
        train_neighbor[iid].append(also_viewed[i])

    for i in range(int(len(also_viewed)*0.8),int(len(also_viewed)*0.9)):
        if (iid,also_viewed[i]) in train_pair_set:
            continue
        val_pairs.append((iid,also_viewed[i]))
        assert (iid,also_viewed[i]) not in train_pair_set

        # add to item_id2idx
        if also_viewed[i] not in item_id2idx:
            item_id2idx[also_viewed[i]] = len(item_id2idx)
        
    for i in range(int(len(also_viewed)*0.9),len(also_viewed)):
        if (iid,also_viewed[i]) in train_pair_set:
            continue
        test_pairs.append((iid,also_viewed[i]))
        assert (iid,also_viewed[i]) not in train_pair_set
        
        # add to item_id2idx
        if also_viewed[i] not in item_id2idx:
            item_id2idx[also_viewed[i]] = len(item_id2idx)
        
print(f'Train/Val/Test size:{len(train_pairs)},{len(val_pairs)},{len(test_pairs)}')
print(f'Train/Val/Test avg:{len(train_pairs)/len(also_viewed_item)},{len(val_pairs)/len(also_viewed_item)},{len(test_pairs)/len(also_viewed_item)}')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 959912/959912 [00:31<00:00, 30873.49it/s]

Train/Val/Test size:7876427,902293,1402757
Train/Val/Test avg:8.205363616664862,0.939974706014718,1.4613391644234055





In [12]:
sample_neighbor_num = 5

In [13]:
# save all the text on node in the graph

node_id_set = set()

with open(f'data_dir/{dataset}/{sub_dataset}/corpus.txt','w') as fout:    
    for iid in tqdm(also_viewed_item):
        also_viewed = also_viewed_item[iid]['related']['also_viewed']
        
        # save iid text
        if iid not in node_id_set:
            node_id_set.add(idd)
            if 'description' in data[iid]:
                tmp_text = text_process(data[iid]['title'] + ' ' + data[iid]['description'])
            else:
                tmp_text = text_process(data[iid]['title'])
            if tmp_text not in ['', 'null']:
                fout.write(idd+'\t'+tmp_text+'\n')
    
        # save neighbor
        for iid_n in also_viewed:
            if iid_n not in node_id_set:
                node_id_set.add(iid_n)
                if 'description' in data[iid_n]:
                    tmp_text = text_process(data[iid_n]['title'] + ' ' + data[iid_n]['description'])
                else:
                    tmp_text = text_process(data[iid_n]['title'])
                if tmp_text not in ['', 'null']:
                    fout.write(iid_n+'\t'+tmp_text+'\n')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 279788/279788 [00:03<00:00, 89284.07it/s]


In [14]:
# generate and save train file

random.seed(0)
sample_neighbor_num = 5

with open(f'data_dir/{dataset}/{sub_dataset}/train.text.jsonl','w') as fout:
    for (q, k) in tqdm(train_pairs):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))
        k_n_pool = set(deepcopy(train_neighbor[k]))

        if k in q_n_pool:
            q_n_pool.remove(k)
        if q in k_n_pool:
            k_n_pool.remove(q)

        q_n_pool = list(q_n_pool)
        k_n_pool = list(k_n_pool)
        random.shuffle(q_n_pool)
        random.shuffle(k_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        if len(k_n_pool) >= sample_neighbor_num:
            k_samples = k_n_pool[:sample_neighbor_num]
        else:
            k_samples = k_n_pool + [-1] * (sample_neighbor_num-len(k_n_pool))
        
        # prepare for writing file
        if 'description' in data[q]:
            q_text = text_process(data[q]['title'] + ' ' + data[q]['description'])
        else:
            q_text = text_process(data[q]['title'])
        q_n_text = []
        for q_n in q_samples:
            if q_n == -1:
                q_n_text.append('')
            elif 'description' in data[q_n]:
                q_n_text.append(text_process(data[q_n]['title'] + ' ' + data[q_n]['description']))
            else:
                q_n_text.append(text_process(data[q_n]['title']))

        if 'description' in data[k]:
            k_text = text_process(data[k]['title'] + ' ' + data[k]['description'])
        else:
            k_text = text_process(data[k]['title'])
        k_n_text = []
        for k_n in k_samples:
            if k_n == -1:
                k_n_text.append('')
            elif 'description' in data[k_n]:
                k_n_text.append(text_process(data[k_n]['title'] + ' ' + data[k_n]['description']))
            else:
                k_n_text.append(text_process(data[k_n]['title']))
        
        #fout.write(q_line+'\t'+k_line+'\n')
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'k_text':k_text,
            'k_n_text':k_n_text,
        })+'\n')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4479834/4479834 [16:43<00:00, 4462.82it/s]


In [15]:
# generate and save val file (make sure to delete items that are not in train set)

random.seed(0)

with open(f'data_dir/{dataset}/{sub_dataset}/val.text.jsonl','w') as fout:
    for (q, k) in tqdm(val_pairs):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))
        k_n_pool = set(deepcopy(train_neighbor[k]))

        if k in q_n_pool:
            q_n_pool.remove(k)
        if q in k_n_pool:
            k_n_pool.remove(q)

        q_n_pool = list(q_n_pool)
        k_n_pool = list(k_n_pool)
        random.shuffle(q_n_pool)
        random.shuffle(k_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        if len(k_n_pool) >= sample_neighbor_num:
            k_samples = k_n_pool[:sample_neighbor_num]
        else:
            k_samples = k_n_pool + [-1] * (sample_neighbor_num-len(k_n_pool))
        
        # prepare for writing file
        if 'description' in data[q]:
            q_text = text_process(data[q]['title'] + ' ' + data[q]['description'])
        else:
            q_text = text_process(data[q]['title'])
        q_n_text = []
        for q_n in q_samples:
            if q_n == -1:
                q_n_text.append('')
            elif 'description' in data[q_n]:
                q_n_text.append(text_process(data[q_n]['title'] + ' ' + data[q_n]['description']))
            else:
                q_n_text.append(text_process(data[q_n]['title']))

        if 'description' in data[k]:
            k_text = text_process(data[k]['title'] + ' ' + data[k]['description'])
        else:
            k_text = text_process(data[k]['title'])
        k_n_text = []
        for k_n in k_samples:
            if k_n == -1:
                k_n_text.append('')
            elif 'description' in data[k_n]:
                k_n_text.append(text_process(data[k_n]['title'] + ' ' + data[k_n]['description']))
            else:
                k_n_text.append(text_process(data[k_n]['title']))
        
        #fout.write(q_line+'\t'+k_line+'\n')
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'k_text':k_text,
            'k_n_text':k_n_text,
        })+'\n')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 457776/457776 [01:40<00:00, 4559.57it/s]


In [16]:
# generate and save test file (make sure to delete items that are not in train set)

random.seed(0)

with open(f'data_dir/{dataset}/{sub_dataset}/test.text.jsonl','w') as fout:
    for (q, k) in tqdm(test_pairs):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))
        k_n_pool = set(deepcopy(train_neighbor[k]))

        if k in q_n_pool:
            q_n_pool.remove(k)
        if q in k_n_pool:
            k_n_pool.remove(q)

        q_n_pool = list(q_n_pool)
        k_n_pool = list(k_n_pool)
        random.shuffle(q_n_pool)
        random.shuffle(k_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        if len(k_n_pool) >= sample_neighbor_num:
            k_samples = k_n_pool[:sample_neighbor_num]
        else:
            k_samples = k_n_pool + [-1] * (sample_neighbor_num-len(k_n_pool))
        
        # prepare for writing file
        if 'description' in data[q]:
            q_text = text_process(data[q]['title'] + ' ' + data[q]['description'])
        else:
            q_text = text_process(data[q]['title'])
        q_n_text = []
        for q_n in q_samples:
            if q_n == -1:
                q_n_text.append('')
            elif 'description' in data[q_n]:
                q_n_text.append(text_process(data[q_n]['title'] + ' ' + data[q_n]['description']))
            else:
                q_n_text.append(text_process(data[q_n]['title']))

        if 'description' in data[k]:
            k_text = text_process(data[k]['title'] + ' ' + data[k]['description'])
        else:
            k_text = text_process(data[k]['title'])
        k_n_text = []
        for k_n in k_samples:
            if k_n == -1:
                k_n_text.append('')
            elif 'description' in data[k_n]:
                k_n_text.append(text_process(data[k_n]['title'] + ' ' + data[k_n]['description']))
            else:
                k_n_text.append(text_process(data[k_n]['title']))
        
        #fout.write(q_line+'\t'+k_line+'\n')
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'k_text':k_text,
            'k_n_text':k_n_text,
        })+'\n')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 589610/589610 [02:13<00:00, 4401.68it/s]


In [17]:
# save side files
pickle.dump([sample_neighbor_num],open(f'data_dir/{dataset}/{sub_dataset}/neighbor_sampling.pkl','wb'))

In [18]:
# save neighbor file
pickle.dump(train_neighbor,open(f'data_dir/{dataset}/{sub_dataset}/neighbor/train_neighbor.pkl','wb'))

In [19]:
# save node labels
random.seed(0)

with open(f'data_dir/{dataset}/{sub_dataset}/nc/node_classification.jsonl','w') as fout:
    for q in tqdm(also_viewed_item):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))

        q_n_pool = list(q_n_pool)
        random.shuffle(q_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        # prepare for writing file
        if 'description' in data[q]:
            q_text = text_process(data[q]['title'] + ' ' + data[q]['description'])
        else:
            q_text = text_process(data[q]['title'])
        q_n_text = []
        for q_n in q_samples:
            if q_n == -1:
                q_n_text.append('')
            elif 'description' in data[q_n]:
                q_n_text.append(text_process(data[q_n]['title'] + ' ' + data[q_n]['description']))
            else:
                q_n_text.append(text_process(data[q_n]['title']))
        
        
        label_names_list = list(set(sum(data[q]['categories'], [])))
        label_names_list = [n for n in label_names_list if n in label_name2id_dict]
        label_ids_list = [label_name2id_dict[n] for n in label_names_list]
        
        if len(label_ids_list) != 0:
            fout.write(json.dumps({
                'q_text':q_text,
                'q_n_text':q_n_text,
                'labels':label_ids_list,
                'label_names':label_names_list
            })+'\n')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 279788/279788 [00:22<00:00, 12297.67it/s]


In [20]:
# generate self constrastive pretraining

corpus_list = []

with open(f'data_dir/{dataset}/{sub_dataset}/corpus.txt') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        tmp = line.strip().split('\t')
        corpus_list.append(tmp[1])
        
with open(f'data_dir/{dataset}/{sub_dataset}/self-train/train.text.jsonl','w') as fout:
    for dd in tqdm(corpus_list):
        fout.write(json.dumps({
            'q_text':dd,
            'q_n_text':[''],
            'k_text':dd,
            'k_n_text':[''],
        })+'\n')

with open(f'data_dir/{dataset}/{sub_dataset}/self-train/val.text.jsonl','w') as fout:
    for dd in tqdm(corpus_list[:int(0.2*len(corpus_list))]):
        fout.write(json.dumps({
            'q_text':dd,
            'q_n_text':[''],
            'k_text':dd,
            'k_n_text':[''],
        })+'\n')
        
with open(f'data_dir/{dataset}/{sub_dataset}/self-train/test.text.jsonl','w') as fout:
    for dd in tqdm(corpus_list[int(0.8*len(corpus_list)):]):
        fout.write(json.dumps({
            'q_text':dd,
            'q_n_text':[''],
            'k_text':dd,
            'k_n_text':[''],
        })+'\n')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 309311/309311 [00:00<00:00, 706064.65it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 309311/309311 [00:03<00:00, 88608.53it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 61862/61862 [00:00<00:00, 86549.53it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 61863/61863 [00:00<00:00, 91207.12it/s]


## Generate node classification data for retrieval and reranking

In [21]:
# write labels into documents.json

labels_dict = []
#for lid in label_name_dict:
for lname in label_name2id_dict:
    if lname != 'null':
        labels_dict.append({'id':label_name2id_dict[lname], 'contents':lname})
json.dump(labels_dict, open(f'data_dir/amazon/{sub_dataset}/nc/documents.json', 'w'), indent=4)

with open(f'data_dir/amazon/{sub_dataset}/nc/documents.txt', 'w') as fout:
    #for lid in label_name_dict:
    for lname in label_name2id_dict:
        if lname == 'null':
            continue
        fout.write(str(label_name2id_dict[lname])+'\t'+lname+'\n')

In [22]:
# generate node query file & ground truth file

docid = 0

with open(f'data_dir/amazon/{sub_dataset}/nc/node_classification.jsonl') as f, open(f'data_dir/amazon/{sub_dataset}/nc/node_text.tsv', 'w') as fout1, open(f'data_dir/amazon/{sub_dataset}/nc/truth.trec', 'w') as fout2:
    readin = f.readlines()
    for line in tqdm(readin):
        tmp = json.loads(line)
        fout1.write(str(docid) + '\t' + tmp['q_text'] + '\n')
        for label in tmp['labels']:
            fout2.write(str(docid)+' '+str(0)+' '+str(label)+' '+str(1)+'\n')
        docid += 1

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 274354/274354 [00:05<00:00, 48064.64it/s]


In [23]:
# generate node query file & ground truth file
## You can just skip this cell if you want to use bm25 negative

docid = 0

with open(f'data_dir/amazon/{sub_dataset}/nc/node_classification.jsonl') as f, open(f'data_dir/amazon/{sub_dataset}/nc/train.text.jsonl', 'w') as fout1, open(f'data_dir/amazon/{sub_dataset}/nc/val.text.jsonl', 'w') as fout2, open(f'data_dir/amazon/{sub_dataset}/nc/test.truth.trec', 'w') as fout3, open(f'data_dir/amazon/{sub_dataset}/nc/test.node.text.jsonl', 'w') as fout4:
    readin = f.readlines()
    total_len = len(readin)
    for line in tqdm(readin[:int(0.8*total_len)]):
        tmp = json.loads(line)
        for label_name in tmp['label_names']:
            fout1.write(json.dumps({
                'q_text':tmp['q_text'],
                'q_n_text':tmp['q_n_text'],
                'k_text':label_name,
                'k_n_text':[''],
            })+'\n')
        docid += 1
    
    for line in tqdm(readin[int(0.8*total_len):int(0.9*total_len)]):
        tmp = json.loads(line)
        for label_name in tmp['label_names']:
            fout2.write(json.dumps({
                'q_text':tmp['q_text'],
                'q_n_text':tmp['q_n_text'],
                'k_text':label_name,
                'k_n_text':[''],
            })+'\n')
        docid += 1
        
    for line in tqdm(readin[int(0.9*total_len):]):
        tmp = json.loads(line)
        #fout4.write(str(docid) + '\t' + tmp['q_text'] + '\n')
        fout4.write(json.dumps({
                'id': str(docid),
                'text':tmp['q_text'],
                'n_text':tmp['q_n_text']
            })+'\n')
        for label in tmp['labels']:
            fout3.write(str(docid)+' '+str(0)+' '+str(label)+' '+str(1)+'\n')
        docid += 1

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 219483/219483 [00:14<00:00, 15537.70it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27435/27435 [00:01<00:00, 14991.04it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27436/27436 [00:01<00:00, 25996.62it/s]


## Generate Coarse-grained Classification Data

In [9]:
# read label name dict
coarse_label_id2name = {}
#coarse_label_id2idx = {}

with open(f'data_dir/amazon/{sub_dataset}/coarse_class.txt') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        tmp = line.strip().split('\t')
        if tmp[1] not in label_name2id_dict:
            continue
        coarse_label_id2name[label_name2id_dict[tmp[1]]] = tmp[1]

print(f'Num of unique labels:{len(coarse_label_id2name)};{coarse_label_id2name}')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 94121.83it/s]

Num of unique labels:16;{6: 'Accessories', 116: 'Action Sports', 27: 'Boating & Water Sports', 2: 'Clothing', 30: 'Cycling', 45: 'Equestrian Sports', 23: 'Exercise & Fitness', 42: 'Fan Shop', 15: 'Golf', 8: 'Hunting & Fishing', 18: 'Leisure Sports & Game Room', 12: 'Outdoor Gear', 154: 'Paintball & Airsoft', 410: 'Racquet Sports', 144: 'Snow Sports', 130: 'Team Sports'}





In [28]:
# generate train/val/test file
# filter out and only use node which has single label

ktrain = 8 # train sample threshold, how many training samples do we have for each class
kdev = 8 # dev sample threshold, how many dev samples do we have for each class
label_samples = defaultdict(list)

with open(f'data_dir/amazon/{sub_dataset}/nc/node_classification.jsonl') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        tmp = json.loads(line)
        inter_label = list(set(tmp['labels']) & set(coarse_label_id2name))
        if len(inter_label) == 1:
            label_samples[inter_label[0]].append(tmp)
            
# select labels
coarse_label_id2idx = {}
for l in label_samples:
    if len(label_samples[l]) > ktrain + kdev:
        coarse_label_id2idx[l] = len(coarse_label_id2idx)
        
print(f'Num of unique labels:{len(coarse_label_id2idx)};{coarse_label_id2idx}')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 349783/349783 [00:08<00:00, 39081.02it/s]

Num of unique labels:14;{2: 0, 6: 1, 8: 2, 12: 3, 15: 4, 27: 5, 30: 6, 23: 7, 18: 8, 116: 9, 130: 10, 154: 11, 144: 12, 45: 13}





In [29]:
# for coarse-grained classification only
random.seed(12)

In [30]:
# save

if not os.path.exists(f'data_dir/amazon/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}'):
    os.mkdir(f'data_dir/amazon/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}')

with open(f'data_dir/amazon/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}/train.text.jsonl', 'w') as fout1, open(f'data_dir/amazon/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}/val.text.jsonl', 'w') as fout2, open(f'data_dir/amazon/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}/test.text.jsonl', 'w') as fout3:
    
    assert ktrain+kdev <= 32
    
    for l in coarse_label_id2idx:
        random.shuffle(label_samples[l])
        
        train_data = label_samples[l][:ktrain]
        dev_data = label_samples[l][ktrain:(ktrain+kdev)]
        #test_data = label_samples[l][(ktrain+kdev):]
        test_data = label_samples[l][16:]
    
        # write train
        for d in train_data:
            fout1.write(json.dumps({
                'q_text':d['q_text'],
                'q_n_text':d['q_n_text'],
                'label':coarse_label_id2idx[l]
            })+'\n')
    
        # write dev
        for d in dev_data:
            fout2.write(json.dumps({
                'q_text':d['q_text'],
                'q_n_text':d['q_n_text'],
                'label':coarse_label_id2idx[l]
            })+'\n')
    
        # write test
        for d in test_data:
            fout3.write(json.dumps({
                'q_text':d['q_text'],
                'q_n_text':d['q_n_text'],
                'label':coarse_label_id2idx[l]
            })+'\n')

pickle.dump(coarse_label_id2idx, open(f'data_dir/amazon/{sub_dataset}/nc-coarse/coarse_label_id2idx.pkl', 'wb'))
pickle.dump([ktrain, kdev], open(f'data_dir/amazon/{sub_dataset}/nc-coarse/threshold.pkl', 'wb'))

In [31]:
coarse_label_idx2id = {coarse_label_id2idx[idd]:idd for idd in coarse_label_id2idx}
with open(f'data_dir/amazon/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}/label_name.txt','w') as fout:
    for i in range(len(coarse_label_idx2id)):
        fout.write(coarse_label_id2name[coarse_label_idx2id[i]]+'\n')

## Downstream link prediction (co-purchase)

In [27]:
## split train/val/test as 8:1:1

random.seed(0)

train_pairs = []
val_pairs = []
test_pairs = []
train_pair_set = set()

for iid in tqdm(also_bought_item):
    if iid not in item_id2idx:
        continue
    
    also_bought = also_bought_item[iid]['related']['also_bought']
    random.shuffle(also_bought)
    
    for i in range(int(len(also_bought)*0.8)):
        if also_bought[i] in item_id2idx:
            train_pairs.append((iid,also_bought[i]))
            train_pair_set.add((iid,also_bought[i]))
            train_pair_set.add((also_bought[i],iid))

    for i in range(int(len(also_bought)*0.8),int(len(also_bought)*0.9)):
        if also_bought[i] in item_id2idx:
            if (iid,also_bought[i]) in train_pair_set:
                continue
            val_pairs.append((iid,also_bought[i]))
            assert (iid,also_bought[i]) not in train_pair_set

    for i in range(int(len(also_bought)*0.9),len(also_bought)):
        if also_bought[i] in item_id2idx:
            if (iid,also_bought[i]) in train_pair_set:
                continue
            test_pairs.append((iid,also_bought[i]))
            assert (iid,also_bought[i]) not in train_pair_set

print(f'Train/Val/Test size:{len(train_pairs)},{len(val_pairs)},{len(test_pairs)}')
print(f'Train/Val/Test avg:{len(train_pairs)/len(also_bought_item)},{len(val_pairs)/len(also_bought_item)},{len(test_pairs)/len(also_bought_item)}')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 123480/123480 [00:03<00:00, 35504.31it/s]

Train/Val/Test size:1320688,124040,169599
Train/Val/Test avg:10.695562034337545,1.0045351473922903,1.3734936831875608





In [28]:
# generate and save train file

random.seed(0)
sample_neighbor_num = 5

with open(f'data_dir/{dataset}/{sub_dataset}/also_bought/train.text.jsonl','w') as fout:
    for (q, k) in tqdm(train_pairs):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))
        k_n_pool = set(deepcopy(train_neighbor[k]))

        if k in q_n_pool:
            q_n_pool.remove(k)
        if q in k_n_pool:
            k_n_pool.remove(q)

        q_n_pool = list(q_n_pool)
        k_n_pool = list(k_n_pool)
        random.shuffle(q_n_pool)
        random.shuffle(k_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        if len(k_n_pool) >= sample_neighbor_num:
            k_samples = k_n_pool[:sample_neighbor_num]
        else:
            k_samples = k_n_pool + [-1] * (sample_neighbor_num-len(k_n_pool))
        
        # prepare for writing file
        if 'description' in data[q]:
            q_text = text_process(data[q]['title'] + ' ' + data[q]['description'])
        else:
            q_text = text_process(data[q]['title'])
        q_n_text = []
        for q_n in q_samples:
            if q_n == -1:
                q_n_text.append('')
            elif 'description' in data[q_n]:
                q_n_text.append(text_process(data[q_n]['title'] + ' ' + data[q_n]['description']))
            else:
                q_n_text.append(text_process(data[q_n]['title']))

        if 'description' in data[k]:
            k_text = text_process(data[k]['title'] + ' ' + data[k]['description'])
        else:
            k_text = text_process(data[k]['title'])
        k_n_text = []
        for k_n in k_samples:
            if k_n == -1:
                k_n_text.append('')
            elif 'description' in data[k_n]:
                k_n_text.append(text_process(data[k_n]['title'] + ' ' + data[k_n]['description']))
            else:
                k_n_text.append(text_process(data[k_n]['title']))
        
        #fout.write(q_line+'\t'+k_line+'\n')
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'k_text':k_text,
            'k_n_text':k_n_text,
        })+'\n')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1320688/1320688 [05:00<00:00, 4394.21it/s]


In [29]:
# generate and save val file (make sure to delete items that are not in train set)

random.seed(0)

with open(f'data_dir/{dataset}/{sub_dataset}/also_bought/val.text.jsonl','w') as fout:
    for (q, k) in tqdm(val_pairs):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))
        k_n_pool = set(deepcopy(train_neighbor[k]))

        if k in q_n_pool:
            q_n_pool.remove(k)
        if q in k_n_pool:
            k_n_pool.remove(q)

        q_n_pool = list(q_n_pool)
        k_n_pool = list(k_n_pool)
        random.shuffle(q_n_pool)
        random.shuffle(k_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        if len(k_n_pool) >= sample_neighbor_num:
            k_samples = k_n_pool[:sample_neighbor_num]
        else:
            k_samples = k_n_pool + [-1] * (sample_neighbor_num-len(k_n_pool))
        
        # prepare for writing file
        if 'description' in data[q]:
            q_text = text_process(data[q]['title'] + ' ' + data[q]['description'])
        else:
            q_text = text_process(data[q]['title'])
        q_n_text = []
        for q_n in q_samples:
            if q_n == -1:
                q_n_text.append('')
            elif 'description' in data[q_n]:
                q_n_text.append(text_process(data[q_n]['title'] + ' ' + data[q_n]['description']))
            else:
                q_n_text.append(text_process(data[q_n]['title']))

        if 'description' in data[k]:
            k_text = text_process(data[k]['title'] + ' ' + data[k]['description'])
        else:
            k_text = text_process(data[k]['title'])
        k_n_text = []
        for k_n in k_samples:
            if k_n == -1:
                k_n_text.append('')
            elif 'description' in data[k_n]:
                k_n_text.append(text_process(data[k_n]['title'] + ' ' + data[k_n]['description']))
            else:
                k_n_text.append(text_process(data[k_n]['title']))
        
        #fout.write(q_line+'\t'+k_line+'\n')
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'k_text':k_text,
            'k_n_text':k_n_text,
        })+'\n')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 124040/124040 [00:20<00:00, 5940.61it/s]


In [30]:
# generate and save test file (make sure to delete items that are not in train set)

random.seed(0)

with open(f'data_dir/{dataset}/{sub_dataset}/also_bought/test.text.jsonl','w') as fout:
    for (q, k) in tqdm(test_pairs):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))
        k_n_pool = set(deepcopy(train_neighbor[k]))

        if k in q_n_pool:
            q_n_pool.remove(k)
        if q in k_n_pool:
            k_n_pool.remove(q)

        q_n_pool = list(q_n_pool)
        k_n_pool = list(k_n_pool)
        random.shuffle(q_n_pool)
        random.shuffle(k_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        if len(k_n_pool) >= sample_neighbor_num:
            k_samples = k_n_pool[:sample_neighbor_num]
        else:
            k_samples = k_n_pool + [-1] * (sample_neighbor_num-len(k_n_pool))
        
        # prepare for writing file
        if 'description' in data[q]:
            q_text = text_process(data[q]['title'] + ' ' + data[q]['description'])
        else:
            q_text = text_process(data[q]['title'])
        q_n_text = []
        for q_n in q_samples:
            if q_n == -1:
                q_n_text.append('')
            elif 'description' in data[q_n]:
                q_n_text.append(text_process(data[q_n]['title'] + ' ' + data[q_n]['description']))
            else:
                q_n_text.append(text_process(data[q_n]['title']))

        if 'description' in data[k]:
            k_text = text_process(data[k]['title'] + ' ' + data[k]['description'])
        else:
            k_text = text_process(data[k]['title'])
        k_n_text = []
        for k_n in k_samples:
            if k_n == -1:
                k_n_text.append('')
            elif 'description' in data[k_n]:
                k_n_text.append(text_process(data[k_n]['title'] + ' ' + data[k_n]['description']))
            else:
                k_n_text.append(text_process(data[k_n]['title']))
        
        #fout.write(q_line+'\t'+k_line+'\n')
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'k_text':k_text,
            'k_n_text':k_n_text,
        })+'\n')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 169599/169599 [00:27<00:00, 6063.63it/s]
