In [None]:
## This script samples paper pairs and generates training data for representation learning on MAG networks.

In [None]:
import os
import math
import random
import json
import pickle
import itertools
import functools
from copy import deepcopy
from tqdm import tqdm
from collections import defaultdict
from typing import List, Dict, Set, Tuple
import numpy as np
# Seed both global RNGs: the samplers below draw from `random` (shuffle)
# and from `np.random` (uniform / permutation).
random.seed(42)
np.random.seed(42)

In [None]:
def load_data(data_root:str, dataset:str, sub_dataset:str) -> Dict:
    """
    Load the raw paper records and build in-dataset reference/citation lists.

    data_root: path to directory contains the data file.
    dataset: dataset directory name under `data_root` (e.g. MAG, Amazon)
    sub_dataset: sub dataset directory name (e.g. CS, sports)

    The file read is `<data_root>/<dataset>/<sub_dataset>/papers_bert.json`,
    one record per line. Blank lines are skipped.

    Returns:
    data: Dict, key is the doc id (the record's 'paper' field), and value is
        the data entry. Each entry's 'reference' list is filtered down to ids
        that are themselves keys of `data`, and a 'citation' list holding the
        reverse edges (ids of the papers that reference it) is added.
    """
    data_path = os.path.join(data_root, dataset, sub_dataset, 'papers_bert.json')
    data = {}
    with open(data_path) as f:
        # Stream the file line by line instead of holding every raw line in
        # memory next to the parsed records.
        for line in tqdm(f, desc="Loading Data..."):
            line = line.strip()
            if not line:
                # A blank line would otherwise make eval('') raise SyntaxError.
                continue
            # SECURITY: eval() executes whatever expression the line contains.
            # Only run this on trusted data files; if the records are plain
            # literals, ast.literal_eval / json.loads would be the safe choice.
            tmp = eval(line)
            k = tmp['paper']
            data[k] = tmp
            data[k]['citation'] = []
    # Second pass: every record is known now, so drop references that point
    # outside the dataset and record the reverse (citation) edge for the rest.
    for k in data:
        kept_refs = []
        for paper in data[k].get('reference', []):
            if paper in data:
                kept_refs.append(paper)
                data[paper]['citation'].append(k)
        data[k]['reference'] = kept_refs
    return data

In [None]:
def build_no_intermediate(data: Dict, type: List[str], max_sample:int = 1250000) -> Set[Tuple[str, str]]:
    """
    Sample direct (paper, neighbour) pairs from an adjacency-list field.

    data: dataset return by `load_data`
    type: list of length 1; type[0] is the record field holding the list of
        neighbour ids (the callers in this notebook pass 'reference' or
        'citation')
    max_sample: overall sampling budget. Each paper receives a share of it
        proportional to the length of its neighbour list; the fractional
        part of that share is rounded up at random.

    Returns:
    id_pair: set of id pairs sampled, as (paper id, neighbour id). Self pairs
        are never emitted. The set is empty when no record has a neighbour.

    The neighbour lists stored in `data` are left untouched: shuffling is
    done on a copy.
    """
    t = type[0]
    id_pair = set()
    keys = list(data.keys())
    cnt = np.array([len(data[k0][t]) for k0 in keys], dtype=float)
    ss = cnt.sum()
    if ss <= 0:
        # No edges at all: cnt / ss would be NaN and int(NaN) raises.
        print(len(id_pair))
        return id_pair
    prob = cnt / ss
    for idx, k0 in enumerate(tqdm(keys)):
        expected = prob[idx] * max_sample
        num_to_sample = int(expected)
        # Stochastic rounding of the fractional share.
        if np.random.uniform(0, 1) <= expected - num_to_sample:
            num_to_sample += 1
        if num_to_sample <= 0:
            continue
        # Shuffle a copy so the caller's data keeps its original order.
        candidates = list(data[k0][t])
        random.shuffle(candidates)
        taken = 0
        for k1 in candidates:
            if k0 == k1 or (k0, k1) in id_pair:
                continue
            id_pair.add((k0, k1))
            taken += 1
            if taken >= num_to_sample:
                break
    print(len(id_pair))
    return id_pair

In [None]:
def build_one_intermediate(data:Dict, type: List[str],  max_sample=1250000) -> Set[Tuple[str, str]]:
    """
    Sample pairs of papers that share one intermediate node.

    data: dataset return by `load_data`
    type: list of length 1; type[0] is the record field holding the
        intermediate node(s). A list or set value contributes the paper to
        one group per element; any other value is used as a single group key.
    max_sample: overall sampling budget. A group with n papers receives a
        share proportional to n*(n+1)/2 (its number of candidate pairs); the
        fractional part of that share is rounded up at random.

    Returns:
    id_pair: set of id pairs sampled. Within a group, a pair is
        (lst[x], lst[y]) with y <= x, so self pairs (p, p) are candidates
        too and each unordered pair appears in one orientation per group.

    Group members are kept in insertion order (dict keys) rather than in a
    set: the pair drawn for a given random index depends on member order,
    and set order for string ids changes between interpreter runs.
    """
    def idx2coord(idx):
        # Invert the row-by-row numbering of a lower triangle (diagonal
        # included): 1 -> (0, 0), 2 -> (1, 0), 3 -> (1, 1), 4 -> (2, 0), ...
        xx = math.ceil(math.sqrt(2*idx+0.25)-0.5)
        yy = idx - xx * (xx-1) // 2
        return xx-1, yy-1

    def sample_random_pair_in_list(lst: List, number:int, res_set: Set):
        # Add up to `number` not-yet-seen pairs from `lst` to `res_set` and
        # return how many were actually added.
        n = len(lst)
        pair_cnt = n*(n+1) // 2
        # NOTE: this materialises a permutation of every candidate pair
        # index, i.e. O(n^2) memory for a group of n papers.
        to_sample = np.random.permutation(pair_cnt)
        cnt = 0
        for raw_idx in to_sample:
            xx, yy = idx2coord(int(raw_idx) + 1)
            assert yy <= xx
            cur_pair = (lst[xx], lst[yy])
            if cur_pair not in res_set:
                res_set.add(cur_pair)
                cnt += 1
                if cnt == number:
                    break
        return cnt

    t = type[0]
    id_pair = set()
    # intermediate node -> papers attached to it (dict keys = ordered set)
    co_type = defaultdict(dict)
    for k0 in tqdm(data):
        inter = data[k0][t]
        members = inter if isinstance(inter, (list, set)) else (inter,)
        for x in members:
            co_type[x][k0] = None
    keys = list(co_type.keys())
    if not keys:
        print(len(id_pair))
        return id_pair
    cnt = np.array([len(co_type[k]) for k in keys], dtype=float)
    cnt = cnt * (cnt+1) / 2.0
    prob = cnt / cnt.sum()
    for idx, k in enumerate(tqdm(keys)):
        expected = prob[idx] * max_sample
        num_sample = int(expected)
        # Stochastic rounding of the fractional share.
        if np.random.uniform(0, 1) <= expected - num_sample:
            num_sample += 1
        if num_sample >= 1:
            sample_random_pair_in_list(list(co_type[k]), num_sample, id_pair)
    print(len(id_pair))
    return id_pair

In [None]:
def convert_and_dump(data: Dict, tuples: Set[Tuple[str, str]], path: str) -> None:
    """
    Dump the sampled pairs into jsonl file

    Each pair (q, k) becomes one line: a JSON object whose 'q_text' and
    'k_text' are the 'title' fields of records q and k.

    data: Dataset returned by `load_data`
    tuples: Sampled tuples
    path: path to save json file; its parent directory is created if it
        does not exist yet
    """
    print("Dump data to %s" % path)
    parent_dir = os.path.dirname(path)
    if parent_dir:
        # open(..., 'w') does not create missing directories.
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w') as fout:
        for q, k in tqdm(tuples, desc="Processing %s" % os.path.basename(path)):
            record = {
                'q_text': data[q]['title'],
                'k_text': data[k]['title'],
            }
            fout.write(json.dumps(record)+'\n')

In [None]:
# Relation name -> sampler. The two-letter relations are direct edges
# (no intermediate node); the three-letter ones pair papers through a
# shared intermediate node.
GENERATOR_DICT = dict(
    pr=build_no_intermediate,
    pc=build_no_intermediate,
    pap=build_one_intermediate,
    pvp=build_one_intermediate,
    pcp=build_one_intermediate,
    prp=build_one_intermediate,
)

In [None]:
# Dataset and sub-dataset to process: element 0 of each candidate list.
datasets = ['MAG'][0]
sub_datasets = ['Mathematics'][0]
# `base_dir` is the data root handed to `load_data`; `save_dir` is where the
# later cells write the sampled pair files.
# NOTE(review): unlike the input path built in `load_data`, `save_dir` has no
# `datasets` component -- confirm this is intended.
base_dir = 'xxx/data/'
save_dir = f'xxx/data/{sub_datasets}/raw'

# Load all records of the selected sub-dataset and show how many were read.
cur_d = load_data(base_dir, datasets, sub_datasets)
print(len(cur_d))

In [None]:
# Show the first (id, record) entry as a quick look at the loaded schema.
for k in itertools.islice(cur_d, 1):
    print(k, cur_d[k], sep='\n')

In [None]:
# Direct (paper, referenced paper) pairs; the output file is named pp.jsonl.
pr = GENERATOR_DICT['pr'](data=cur_d, type=['reference'])
convert_and_dump(data=cur_d, tuples=pr, path=os.path.join(save_dir, 'pp.jsonl'))

In [None]:
# Direct (paper, citing paper) pairs, taken from the 'citation' lists.
pc = GENERATOR_DICT['pc'](data=cur_d, type=['citation'])
convert_and_dump(data=cur_d, tuples=pc, path=os.path.join(save_dir, 'pc.jsonl'))

In [None]:
# Pairs of papers that share an 'author' value.
pap = GENERATOR_DICT['pap'](data=cur_d, type=['author'])
convert_and_dump(data=cur_d, tuples=pap, path=os.path.join(save_dir, 'pap.jsonl'))

In [None]:
# Pairs of papers that share a 'venue' value.
pvp = GENERATOR_DICT['pvp'](data=cur_d, type=['venue'])
convert_and_dump(data=cur_d, tuples=pvp, path=os.path.join(save_dir, 'pvp.jsonl'))

In [None]:
# Pairs of papers that have a 'reference' entry in common.
prp = GENERATOR_DICT['prp'](data=cur_d, type=['reference'])
convert_and_dump(data=cur_d, tuples=prp, path=os.path.join(save_dir, 'prp.jsonl'))

In [None]:
# Pairs of papers that have a 'citation' entry in common.
pcp = GENERATOR_DICT['pcp'](data=cur_d, type=['citation'])
convert_and_dump(data=cur_d, tuples=pcp, path=os.path.join(save_dir, 'pcp.jsonl'))