In [None]:
import os
import math
import random
import json
import pickle
import itertools
import functools
from copy import deepcopy
from tqdm import tqdm
import seaborn as sns
from collections import defaultdict
from typing import List, Dict, Set, Tuple
import numpy as np
random.seed(42)
np.random.seed(42)

## Load data

In [None]:
def load_data(data_root:str, dataset:str, sub_dataset:str) -> Dict:
    """
    Load the paper records of one sub-dataset and link them by citation.

    data_root: path to the directory that contains the data files.
    dataset: dataset directory name (e.g. MAG)
    sub_dataset: sub dataset name (e.g. CS)

    Returns:
    data: Dict, key is the doc id (the 'paper' field), and value is the data entry.
          Every entry gets a 'citation' list (ids of loaded papers that reference it),
          and its 'reference' list is restricted to ids that are present in `data`
          (an entry without a 'reference' field gets an empty list).
    """
    # Function-scope import keeps this cell runnable on its own.
    import ast

    # read raw data
    data_path = os.path.join(data_root, dataset, sub_dataset, 'papers_bert.json')
    data = {}
    with open(data_path) as f:
        for line in tqdm(f, desc="Loading Data..."):
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing empty line)
                continue
            # Each line is expected to be a Python dict literal. literal_eval only
            # parses literals, whereas eval() would execute any expression found
            # in the data file.
            tmp = ast.literal_eval(line)
            k = tmp['paper']
            data[k] = tmp
            data[k]['citation'] = []
    for k in data:
        # keep only references that point inside the loaded set and record the
        # reverse edge on the referenced paper
        new_refs = []
        for paper in data[k].get('reference', []):
            if paper in data:
                new_refs.append(paper)
                data[paper]['citation'].append(k)
        data[k]['reference'] = new_refs
    return data

In [None]:
def load_meta_data(base_dir, data_name):
    """
    Load the venue id -> venue name mapping of one sub-dataset.

    Reads <base_dir>/<data_name>/venues.txt. Every non-blank line must hold three
    tab-separated fields; the first is used as the venue id and the third as the
    venue name (the middle field is ignored).

    Returns:
    vid2name: defaultdict(str), so looking up an unknown venue id yields ''.
    """
    p = os.path.join(base_dir, data_name, 'venues.txt')
    vid2name = defaultdict(str)
    with open(p) as f:
        for l in f:
            l = l.strip()
            if not l:
                # skip blank lines, the same tolerance load_label has
                continue
            # named `vid` so the builtin `id` is not shadowed
            vid, _, name = l.split('\t')
            vid2name[vid] = name
    return vid2name

In [None]:
def load_label(data_path, subdataset):
    """
    Read the label table of a MAG sub-dataset.

    The file <data_path>/MAG/<subdataset>/labels.txt is parsed line by line.
    Blank lines are ignored; every other line must consist of exactly three
    tab-separated fields: label id, label name, level.

    Returns:
    id2label: dict, label id -> label name
    id2level: dict, label id -> level (int)
    """
    label_file = os.path.join(data_path, 'MAG', subdataset, 'labels.txt')
    names, levels = {}, {}
    with open(label_file) as fin:
        for raw in tqdm(fin):
            row = raw.strip()
            if not row:
                continue
            label_id, label_name, label_level = row.split('\t')
            names[label_id] = label_name
            levels[label_id] = int(label_level)
    return names, levels

In [None]:
# Notebook configuration: base directories and the sub-dataset to process.
# NOTE(review): the 'xxx/' values look like placeholders - set real paths before running.
meta_base_dir = 'xxx/'
raw_base_dir = 'xxx/data/'
save_base_dir = 'xxx/data/'
# Index [1] selects 'Mathematics'; use [0] for 'Geology'.
data_name = ['Geology', 'Mathematics'][1]
# NOTE(review): the two directories below are not referenced by any later cell shown here.
meta_data_dir = os.path.join(meta_base_dir, data_name)
raw_data_dir = os.path.join(raw_base_dir, data_name)

In [None]:
# Paper records keyed by paper id, read from <raw_base_dir>/MAG/<data_name>/papers_bert.json.
data = load_data(raw_base_dir, 'MAG', data_name)

In [None]:
# Venue id -> venue name, read from <meta_base_dir>/<data_name>/venues.txt.
vid2n = load_meta_data(meta_base_dir, data_name)

In [None]:
# Label id -> name and label id -> level, read from <raw_base_dir>/MAG/<data_name>/labels.txt.
id2label,id2level = load_label(raw_base_dir, data_name)

## paper recommendation

In [None]:
def load_raw_user_action_data(base_dir, subdataset):
    """
    Read the raw paper-recommendation records of a sub-dataset.

    The file <base_dir>/<subdataset>/raw/PaperRecommendation.txt is read line by
    line; each line is stripped and split on tabs.

    Returns:
    list of lists of str, one inner list (the tab-separated fields) per line.
    """
    source = os.path.join(base_dir, subdataset, 'raw', 'PaperRecommendation.txt')
    with open(source) as fin:
        progress = tqdm(fin, desc='Loading user action data of %s...' % subdataset)
        return [row.strip().split('\t') for row in progress]

In [None]:
def generate_paper_recommendation(paper_data, user_data, threshold):
    """
    Turn raw (paper id, paper id, score) rows into title pairs.

    paper_data: dict, paper id -> data entry (the 'title' field is used)
    user_data: iterable of rows; rows that do not have exactly 3 fields are ignored
    threshold: a row is kept only if float(score) is strictly greater than this value

    A row is also dropped when either paper id is missing from paper_data, when
    either entry has no 'title', or when a stripped title is empty.

    Returns:
    list of (title_a, title_b, score) tuples; score is passed through unconverted.
    """
    res = []
    for item in tqdm(user_data, desc='Generating paper recommendation data'):
        if len(item) != 3:
            continue
        pa, pb, score = item
        # .get(): ids that are not part of paper_data are skipped instead of
        # aborting the whole run with a KeyError
        entry_a = paper_data.get(pa)
        entry_b = paper_data.get(pb)
        if entry_a is None or entry_b is None:
            continue
        if 'title' not in entry_a or 'title' not in entry_b:
            continue
        ta = entry_a['title'].strip()
        tb = entry_b['title'].strip()
        if len(ta) > 0 and len(tb) > 0 and float(score) > threshold:
            res.append((ta, tb, score))
    return res

In [None]:
def write_paper_recommendation(save_base_dir, data_name, data):
    """
    Write paper-recommendation pairs as JSON lines.

    Output file: <save_base_dir>/<data_name>/downstream/paper_recommendation/data.jsonl
    (the directory is created when it does not exist).

    data: iterable of indexable items; item[0] -> 'q_text', item[1] -> 'k_text',
          float(item[2]) -> 'score'
    """
    out_dir = os.path.join(save_base_dir, data_name, 'downstream', 'paper_recommendation')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    data_path = os.path.join(out_dir, 'data.jsonl')
    with open(data_path, 'w') as fout:
        for record in tqdm(data, desc="Write data to %s" % data_path):
            payload = {
                'q_text': record[0],
                'k_text': record[1],
                'score': float(record[2]),
            }
            fout.write(json.dumps(payload) + '\n')

In [None]:
# Raw rows from <save_base_dir>/<data_name>/raw/PaperRecommendation.txt (note: save_base_dir, not raw_base_dir).
user_action_data = load_raw_user_action_data(save_base_dir, data_name)

In [None]:
# Number of raw rows loaded (shown as the cell output).
len(user_action_data)

In [None]:
# Keep only pairs whose score is strictly greater than 0.9.
# NOTE: `tmp` is a notebook-global name that the statistics cells further down
# reuse as a loop variable, so write this result out before running them.
tmp = generate_paper_recommendation(data, user_action_data, 0.9)
print(len(tmp))

In [None]:
# Writes <save_base_dir>/<data_name>/downstream/paper_recommendation/data.jsonl.
write_paper_recommendation(save_base_dir, data_name, tmp)

## author disambiguation

In [None]:
def generate_author_disambiguation(data):
    """
    Collect (title, author) pairs and a per-author list of paper titles.

    data: dict, paper id -> data entry; entries without an 'author' field are skipped.

    Prints the number of distinct authors, then the number of pairs and authors.

    Returns:
    pairs: set of (title, author) tuples
    author_info: defaultdict(dict), author -> {'id': author, 'paper': tuple of titles};
                 authors with more than 100 titles keep a random sample of 100.
    """
    titles_by_author = defaultdict(list)
    pairs = set()
    for pid in tqdm(data):
        record = data[pid]
        if 'author' not in record:
            continue
        for author in record['author']:
            pairs.add((record['title'], author))
            titles_by_author[author].append(record['title'])
    print(len(titles_by_author))

    author_info = defaultdict(dict)
    for author, titles in titles_by_author.items():
        # cap the per-author title list at 100 via random sampling
        kept = random.sample(titles, 100) if len(titles) > 100 else titles
        author_info[author] = {'id': author, 'paper': tuple(kept)}
    print(len(pairs), len(author_info))
    return pairs, author_info

In [None]:
def write_author_advanced(save_base_dir, data_name, data, author_info_dict):
    """
    Write the author-paper matching file for author disambiguation.

    Output file: <save_base_dir>/<data_name>/downstream/author_disambiguation/data_ap.jsonl
    (the directory is created when it does not exist).

    data: iterable of (paper title, author id) pairs
    author_info_dict: author id -> {'paper': sequence of titles, ...}

    For each pair, the author's titles minus one occurrence of the query title are
    shuffled and joined with spaces as 'k_text'. Authors whose title list has
    exactly one element are skipped.
    """
    out_dir = os.path.join(save_base_dir, data_name, 'downstream', 'author_disambiguation')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # author paper matching
    with open(os.path.join(out_dir, 'data_ap.jsonl'), 'w') as fout:
        for title, author_id in tqdm(data):
            others = list(author_info_dict[author_id]['paper'])
            if len(others) == 1:
                continue
            if title in others:
                # drop the query paper itself from the key text
                others.remove(title)
            random.shuffle(others)
            fout.write(json.dumps({'q_text': title, 'k_text': ' '.join(others)}) + '\n')

In [None]:
# d: set of (title, author) pairs; vinfo: per-author info dict.
# NOTE: the names `d` and `vinfo` are rebound by the venue-recommendation section below.
d, vinfo = generate_author_disambiguation(data)

In [None]:
# Writes <save_base_dir>/<data_name>/downstream/author_disambiguation/data_ap.jsonl.
write_author_advanced(save_base_dir, data_name, d, vinfo)

## venue recommendation

In [None]:
def generate_venue_recommendation(data, vid2n):
    """
    Collect (title, venue) pairs and per-venue metadata.

    data: dict, paper id -> data entry; entries without a 'venue' field are skipped.
    vid2n: mapping venue id -> venue name, indexed with vid2n[venue_id]

    Prints the number of distinct venues, then the number of pairs and venues.

    Returns:
    pairs: set of (title, venue id) tuples
    venue_list: list of {'id', 'name', 'paper'} dicts, one per venue
    venue_lookup: defaultdict(dict), venue id -> dict with the same content
    In both, 'paper' is a tuple of titles; venues with more than 100 titles keep a
    random sample of 100.
    """
    titles_by_venue = defaultdict(list)
    pairs = set()
    for pid in tqdm(data):
        record = data[pid]
        if 'venue' in record:
            pairs.add((record['title'], record['venue']))
            titles_by_venue[record['venue']].append(record['title'])
    print(len(titles_by_venue))

    venue_list = []
    venue_lookup = defaultdict(dict)
    for venue_id, titles in titles_by_venue.items():
        venue_name = vid2n[venue_id]
        # cap the per-venue title list at 100 via random sampling
        papers = tuple(random.sample(titles, 100)) if len(titles) > 100 else tuple(titles)
        venue_list.append({'id': venue_id, 'name': venue_name, 'paper': papers})
        venue_lookup[venue_id] = {'id': venue_id, 'name': venue_name, 'paper': papers}
    print(len(pairs), len(venue_list))
    return pairs, venue_list, venue_lookup

In [None]:
def write_venue(save_base_dir, data_name, data, venue_info):
    """
    Write the basic venue-recommendation files.

    Output directory: <save_base_dir>/<data_name>/downstream/venue_recommendation
    (created when it does not exist).
      data.tsv    - one '<title><TAB><venue id>' line per pair in `data`
      venue.jsonl - one JSON object per element of `venue_info`

    data: iterable of (paper title, venue id) pairs
    venue_info: iterable of JSON-serialisable venue records
    """
    out_dir = os.path.join(save_base_dir, data_name, 'downstream', 'venue_recommendation')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    pair_path = os.path.join(out_dir, 'data.tsv')
    venue_path = os.path.join(out_dir, 'venue.jsonl')
    with open(pair_path, 'w') as fout:
        fout.writelines("%s\t%s\n" % (title, str(venue_id)) for title, venue_id in data)
    with open(venue_path, 'w') as fout:
        fout.writelines(json.dumps(record) + '\n' for record in venue_info)

In [None]:
def write_venue_advanced(save_base_dir, data_name, data, venue_info_dict):
    """
    Write the venue-name and venue-paper matching files.

    Output directory: <save_base_dir>/<data_name>/downstream/venue_recommendation
    (created when it does not exist).
      data_vn.jsonl - 'q_text' is the paper title, 'k_text' the venue name
      data_vp.jsonl - 'q_text' is the paper title, 'k_text' the venue's paper titles,
                      shuffled and joined with spaces

    data: re-iterable collection of (paper title, venue id) pairs (it is traversed twice)
    venue_info_dict: venue id -> {'name': ..., 'paper': sequence of titles}
    """
    out_dir = os.path.join(save_base_dir, data_name, 'downstream', 'venue_recommendation')
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # venue name matching
    with open(os.path.join(out_dir, 'data_vn.jsonl'), 'w') as fout:
        for title, venue_id in tqdm(data):
            record = {'q_text': title, 'k_text': venue_info_dict[venue_id]['name']}
            fout.write(json.dumps(record) + '\n')

    # venue paper matching
    with open(os.path.join(out_dir, 'data_vp.jsonl'), 'w') as fout:
        for title, venue_id in tqdm(data):
            shuffled = list(venue_info_dict[venue_id]['paper'])
            random.shuffle(shuffled)
            record = {'q_text': title, 'k_text': ' '.join(shuffled)}
            fout.write(json.dumps(record) + '\n')

In [None]:
# d: set of (title, venue id) pairs; v: list of venue records; vinfo: venue id -> record.
# NOTE: this rebinds `d` and `vinfo` from the author-disambiguation section above.
d, v, vinfo = generate_venue_recommendation(data, vid2n)

In [None]:
# Writes data.tsv and venue.jsonl under <save_base_dir>/<data_name>/downstream/venue_recommendation.
write_venue(save_base_dir, data_name, d, v)

In [None]:
# Writes data_vn.jsonl and data_vp.jsonl into the same venue_recommendation directory.
write_venue_advanced(save_base_dir, data_name, d, vinfo)

## regression task

In [None]:
def generate_regression(data, kw_name, func=None):
    """
    Build (title, target) pairs for a regression task.

    data: dict, paper id -> data entry
    kw_name: field of the entry that holds the target; entries without it are skipped
    func: optional callable applied to the raw field value to obtain the target

    Returns:
    set of (title, target) tuples
    """
    pairs = set()
    for record in data.values():
        if kw_name not in record:
            continue
        target = record[kw_name]
        title = record['title']
        pairs.add((title, target if func is None else func(target)))
    return pairs

In [None]:
def write_regression(save_base_dir, data_name, task_name, data, theshold, minus=0):
    """
    Write regression samples as JSON lines.

    Output file: <save_base_dir>/<data_name>/downstream/<task_name>/data.jsonl
    (the directory is created when it does not exist).

    data: iterable of (text, value) pairs
    theshold: samples with float(value) strictly greater than this are dropped
    minus: offset subtracted from float(value) to form the stored 'label'
    """
    out_dir = os.path.join(save_base_dir, data_name, 'downstream', task_name)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    data_path = os.path.join(out_dir, 'data.jsonl')
    print("Write to %s" % data_path)
    with open(data_path, 'w') as fout:
        for text, value in tqdm(data):
            if not (float(value) > theshold):
                sample = {"q_text": text, "label": float(value) - minus}
                fout.write(json.dumps(sample) + '\n')

In [None]:
# Target is the length of each paper's 'citation' list, i.e. its in-dataset citation count.
cite_pred = generate_regression(data, 'citation', len)

In [None]:
# Target is the raw 'year' field of each paper (no conversion applied here).
year_pred = generate_regression(data, 'year')

In [None]:
## Statistics of citation counts: kernel density estimate over all samples in cite_pred.
citation_cnt_list = []
# NOTE: `tmp` is a notebook-global name; this loop rebinds the `tmp` assigned in the
# paper-recommendation section above.
for tmp in cite_pred:
    citation_cnt_list.append(tmp[1])  # second element of each (title, count) pair
sns.kdeplot(citation_cnt_list)

In [None]:
## Statistics of publication year: minimum year and kernel density estimate over year_pred.
year_list = []
# NOTE: this loop also rebinds the notebook-global `tmp`.
for tmp in year_pred:
    year_list.append(int(tmp[1]))  # second element of each (title, year) pair
# The minimum year is the reference for the `minus` offset passed to write_regression below.
print(np.min([float(y) for y in year_list]))
sns.kdeplot(year_list)

In [None]:
# Samples with more than 100 citations are dropped; labels are the raw counts (minus defaults to 0).
write_regression(save_base_dir, data_name, 'citation_prediction', cite_pred, 100)

In [None]:
# !!! Adjust the `minus` offset (1981.0 here) for each dataset - presumably it should
# match the minimum year printed by the statistics cell above; confirm before running.
# The threshold of 20000 only drops samples whose year is greater than 20000.
write_regression(save_base_dir, data_name, 'year_prediction', year_pred, 20000, 1981.0)

## classification

In [None]:
def generate_classification(data, id2level, id2label):
    """
    Build coarse (level-1) classification samples.

    data: dict, paper id -> data entry; only entries with both 'label' and 'title' are used
    id2level: dict, label id -> level; the ids with level 1 form the class set
    id2label: dict, label id -> label name

    Class indices are assigned by enumerating the sorted level-1 label ids. Label ids
    found in the data are converted with str() before they are looked up.
    Prints the index mapping, the number of labels that occur, and per-label counts.

    Returns:
    samples: set of (title, class index) tuples
    mapping: dict, label name -> class index
    """
    # level-1 label ids in sorted order -> consecutive class indices
    level_one = sorted(lid for lid, lv in id2level.items() if lv == 1)
    label_index = {lid: idx for idx, lid in enumerate(level_one)}
    print(len(label_index), label_index)

    samples = set()
    per_label = defaultdict(int)
    for pid in tqdm(data):
        record = data[pid]
        if 'label' not in record or 'title' not in record:
            continue
        for raw_lid in record['label']:
            lid = str(raw_lid)
            if lid in id2level and lid in label_index:
                samples.add((record['title'], label_index[lid]))
                per_label[lid] += 1
    print("Number of label %d" % len(per_label))
    print("Label Mapping")
    name_to_index = {id2label[lid]: label_index[lid] for lid in label_index}
    print(name_to_index)
    print("Label Count")
    print({id2label[lid]: per_label[lid] for lid in per_label})
    # sns.barplot(data=list(per_label.values()))
    return samples, name_to_index

In [None]:
def write_json(save_base_dir, data_name, task_name, data):
    """
    Dump `data` as one indented JSON document.

    Output file: <save_base_dir>/<data_name>/downstream/<task_name>/label.jsonl
    Note that, despite the '.jsonl' extension, the file holds a single JSON value
    written with indent=4, not one object per line.

    data: any JSON-serialisable object
    """
    tmp_base = os.path.join(save_base_dir, data_name, 'downstream', task_name)
    # Create the task directory like the other writers do, so this function also
    # works when it is the first one called for a task.
    os.makedirs(tmp_base, exist_ok=True)
    data_path = os.path.join(tmp_base, 'label.jsonl')
    with open(data_path,'w') as fout:
        json.dump(data, fout, indent = 4)

In [None]:
def write_classification(save_base_dir, data_name, task_name, data):
    """
    Write classification samples as JSON lines.

    Output file: <save_base_dir>/<data_name>/downstream/<task_name>/data.jsonl
    (the directory is created when it does not exist).

    data: iterable of (text, label) pairs; each becomes {"q_text": text, "label": label}
    """
    out_dir = os.path.join(save_base_dir, data_name, 'downstream', task_name)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    data_path = os.path.join(out_dir, 'data.jsonl')
    print("Write to %s" % data_path)
    with open(data_path, 'w') as fout:
        for text, label in tqdm(data):
            sample = {"q_text": text, "label": label}
            fout.write(json.dumps(sample) + '\n')

In [None]:
# NOTE(review): this rebinds `id2label` from "label id -> name" to the returned
# "label name -> class index" mapping. Re-running this cell without first re-running
# the load_label cell passes the wrong mapping in and will most likely fail with a
# KeyError inside generate_classification. The write_json cell below relies on the
# rebound value.
coarse_classification, id2label = generate_classification(data, id2level, id2label)

In [None]:
# write_classification runs first and creates the task directory; write_json then
# stores the label name -> class index mapping as label.jsonl in the same directory.
write_classification(save_base_dir, data_name, 'coarse_classification', coarse_classification)
write_json(save_base_dir, data_name, 'coarse_classification', id2label)

## retrieval

In [None]:
def init_retrieval_label(label_dir):
    """
    Build the label id <-> label name lookups used by the retrieval task.

    label_dir: path of the label file itself (it is opened directly). Every non-blank
               line must have at least two tab-separated fields: label id, label name.

    Prints the number of distinct label names.

    Returns:
    label_name_dict: dict, label id -> label name
    label_name2id_dict: dict, label name -> label id (for a name shared by several
                        ids, the id on the last such line wins)
    """
    # read label name dict
    label_name_dict = {}
    label_name2id_dict = {}

    with open(label_dir) as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                # skip blank lines, which load_label also tolerates
                continue
            tmp = line.split('\t')
            label_name_dict[tmp[0]] = tmp[1]
            label_name2id_dict[tmp[1]] = tmp[0]

    # the keys of label_name2id_dict are exactly the distinct label names
    print(f'Num of unique labels:{len(label_name2id_dict)}')

    return label_name_dict, label_name2id_dict

In [None]:
def write_retrieval_base(save_base_dir, data_name, data, label_name_dict, label_name2id_dict):
    """
    Write one JSON line per paper with its title and de-duplicated labels.

    Output file: <save_base_dir>/<data_name>/downstream/retrieval/node_classification.jsonl
    (the directory is created when it does not exist).

    data: dict, paper id -> data entry with 'title' and 'label' fields
    label_name_dict: dict, label id -> label name
    label_name2id_dict: dict, label name -> label id

    Each line holds 'q_text' (title), 'label_names' (distinct names, sorted) and
    'labels' (the id that label_name2id_dict assigns to each of those names).
    """
    # save node labels
    # Kept from the original implementation: resets the global RNG state.
    random.seed(0)

    tmp_base = os.path.join(save_base_dir, data_name, 'downstream', 'retrieval')
    if not os.path.exists(tmp_base):
        os.makedirs(tmp_base)
    data_path = os.path.join(tmp_base, 'node_classification.jsonl')

    with open(data_path, 'w') as fout:
        for q in tqdm(data):

            q_text = data[q]['title']

            label_names = set()
            for lid in data[q]['label']:
                # Fall back to str(lid) when the raw id is not a key, so non-string
                # ids in the data resolve against string keys (generate_classification
                # applies the same str() normalisation).
                key = lid if lid in label_name_dict else str(lid)
                label_names.add(label_name_dict[key])
            # sorted(): iteration order of a set of strings is not stable across
            # interpreter runs, which made the output file irreproducible
            label_names_list = sorted(label_names)
            label_ids_list = [label_name2id_dict[lname] for lname in label_names_list]

            fout.write(json.dumps({
                'q_text':q_text,
                'labels':label_ids_list,
                'label_names':label_names_list
            })+'\n')

In [None]:
def write_retrieval(save_base_dir, data_name, data, label_name2id_dict):
    """
    Write the label documents and the query / ground-truth files of the retrieval task.

    Files written to <save_base_dir>/<data_name>/downstream/retrieval (created when
    it does not exist):
      documents.json - list of {'id', 'contents'} for every label name except 'null'
      documents.txt  - '<label id><TAB><label name>' per line, 'null' skipped
      node_text.tsv  - '<docid><TAB><q_text>' per line
      truth.trec     - '<docid> 0 <label id> 1' per (query, label) pair

    node_text.tsv and truth.trec are derived from node_classification.jsonl in the
    same directory, which must already exist (write_retrieval_base produces it);
    docid is the 0-based line number within that file.

    data: not used by this function; kept so the signature stays unchanged
    label_name2id_dict: dict, label name -> label id
    """
    tmp_base = os.path.join(save_base_dir, data_name, 'downstream', 'retrieval')
    if not os.path.exists(tmp_base):
        os.makedirs(tmp_base)

    label_json_path = os.path.join(tmp_base, 'documents.json')
    print("Write to %s" % label_json_path)
    labels_dict = [
        {'id': label_name2id_dict[lname], 'contents': lname}
        for lname in label_name2id_dict
        if lname != 'null'
    ]
    # context manager: the handle is flushed and closed deterministically instead
    # of being left to garbage collection
    with open(label_json_path, 'w') as fout:
        json.dump(labels_dict, fout, indent=4)

    label_txt_path = os.path.join(tmp_base, 'documents.txt')
    print("Write to %s" % label_txt_path)
    with open(label_txt_path, 'w') as fout:
        for lname in label_name2id_dict:
            if lname == 'null':
                continue
            fout.write(label_name2id_dict[lname]+'\t'+lname+'\n')

    data_class_path = os.path.join(tmp_base, 'node_classification.jsonl')
    print("Read from %s" % data_class_path)
    node_text_path = os.path.join(tmp_base, 'node_text.tsv')
    print("Write to %s" % node_text_path)
    trec_path = os.path.join(tmp_base, 'truth.trec')
    print("Write to %s" % trec_path)
    with open(data_class_path) as f, open(node_text_path, 'w') as fout1, open(trec_path, 'w') as fout2:
        for docid, line in enumerate(tqdm(f)):
            tmp = json.loads(line)
            fout1.write(str(docid) + '\t' + tmp['q_text'] + '\n')
            for label in tmp['labels']:
                fout2.write(str(docid)+' '+str(0)+' '+label+' '+str(1)+'\n')

In [None]:
# Despite its name, `label_dir` is the path of the labels.txt file itself.
label_dir = os.path.join(raw_base_dir, 'MAG', data_name, 'labels.txt')
label_name_dict, label_name2id_dict = init_retrieval_label(label_dir)
# Order matters: write_retrieval reads the node_classification.jsonl file that
# write_retrieval_base produces.
write_retrieval_base(save_base_dir, data_name, data, label_name_dict, label_name2id_dict)
write_retrieval(save_base_dir, data_name, data, label_name2id_dict)