In [1]:
import os
import random
import json
import pickle
from copy import deepcopy
from tqdm import tqdm
from collections import defaultdict

import numpy as np
from transformers import BertTokenizerFast

In [2]:
random.seed(0)

In [3]:
dataset = 'MAG' 
sub_dataset='Mathematics'

# Generate Pretraining Data

In [4]:
# read raw data
with open(f'data_dir/MAG/{sub_dataset}/papers_bert.json') as f:
    data = {}
    readin = f.readlines()
    for line in tqdm(readin):
        #data.append(json.loads(line))
        #data.append(eval(line.strip()))
        tmp = eval(line.strip())
        data[tmp['paper']] = tmp
#random.shuffle(data)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 490551/490551 [00:51<00:00, 9545.93it/s]


In [5]:
# read label name dict
label_name_dict = {}
label_name_set = set()
label_name2id_dict = {}

with open(f'data_dir/MAG/{sub_dataset}/labels.txt') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        tmp = line.strip().split('\t')
        label_name_dict[tmp[0]] = tmp[1]
        label_name2id_dict[tmp[1]] = tmp[0]
        label_name_set.add(tmp[1])

print(f'Num of unique labels:{len(label_name_set)}')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14271/14271 [00:00<00:00, 805773.88it/s]

Num of unique labels:14010





In [7]:
# filter related

idd_set = set(list(data.keys()))

for idd in tqdm(data):
    if 'reference' not in data[idd] or len(data[idd]['reference']) == 0:
        continue
        
    data[idd]['reference'] = list(set(data[idd]['reference']) & idd_set)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 490551/490551 [00:03<00:00, 147930.32it/s]


In [8]:
# text processing function
def text_process(text):
    p_text = ' '.join(text.split('\r\n'))
    p_text = ' '.join(p_text.split('\n\r'))
    p_text = ' '.join(p_text.split('\n'))
    p_text = ' '.join(p_text.split('\t'))
    p_text = ' '.join(p_text.split('\rm'))
    p_text = ' '.join(p_text.split('\r'))
    p_text = ''.join(p_text.split('$'))
    p_text = ''.join(p_text.split('*'))

    return p_text

In [9]:
# average edge

ref_cnt = 0
ref_paper = {}


for idd in tqdm(data):
    if 'reference' not in data[idd] or len(data[idd]['reference']) == 0:
        continue
        
    ref_cnt += len(data[idd]['reference'])
    ref_paper[idd] = data[idd]

print(f'avg ref cnt:{ref_cnt/len(ref_paper)}.')
print(f'ref papers:{len(ref_paper)}')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 490551/490551 [00:00<00:00, 754943.51it/s]

avg ref cnt:6.276238238575368.
ref papers:380056





In [10]:
## split train/val/test as 8:1:1

random.seed(0)

train_pairs = []
val_pairs = []
test_pairs = []
train_pair_set = set()
item_id2idx = {}
train_neighbor = defaultdict(list)
val_neighbor = defaultdict(list)
test_neighbor = defaultdict(list)

for iid in tqdm(ref_paper):
    if iid not in item_id2idx:
        item_id2idx[iid] = len(item_id2idx)
    
    also_viewed = ref_paper[iid]['reference']
    random.shuffle(also_viewed)
    
    for i in range(int(len(also_viewed)*0.8)):
        train_pairs.append((iid,also_viewed[i]))
        train_pair_set.add((iid,also_viewed[i]))
        train_pair_set.add((also_viewed[i],iid))
        
        # add to item_id2idx
        if also_viewed[i] not in item_id2idx:
            item_id2idx[also_viewed[i]] = len(item_id2idx)

        # add to train_user_neighbor/train_item_neighbor
        train_neighbor[iid].append(also_viewed[i])

    for i in range(int(len(also_viewed)*0.8),int(len(also_viewed)*0.9)):
        if (iid,also_viewed[i]) in train_pair_set:
            continue
        val_pairs.append((iid,also_viewed[i]))
        assert (iid,also_viewed[i]) not in train_pair_set

        # add to item_id2idx
        if also_viewed[i] not in item_id2idx:
            item_id2idx[also_viewed[i]] = len(item_id2idx)
        
        # add to train_user_neighbor/train_item_neighbor
        val_neighbor[iid].append(also_viewed[i])
        
    for i in range(int(len(also_viewed)*0.9),len(also_viewed)):
        if (iid,also_viewed[i]) in train_pair_set:
            continue
        test_pairs.append((iid,also_viewed[i]))
        assert (iid,also_viewed[i]) not in train_pair_set
        
        # add to item_id2idx
        if also_viewed[i] not in item_id2idx:
            item_id2idx[also_viewed[i]] = len(item_id2idx)
        
        # add to train_user_neighbor/train_item_neighbor
        test_neighbor[iid].append(also_viewed[i])
        
print(f'Train/Val/Test size:{len(train_pairs)},{len(val_pairs)},{len(test_pairs)}')
print(f'Train/Val/Test avg:{len(train_pairs)/len(ref_paper)},{len(val_pairs)/len(ref_paper)},{len(test_pairs)/len(ref_paper)}')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 380056/380056 [00:10<00:00, 36588.69it/s]

Train/Val/Test size:1733292,194157,455459
Train/Val/Test avg:4.560622645083883,0.5108641884353885,1.1983997095164922





In [14]:
# save all the text on node in the graph

node_id_set = set()

with open(f'data_dir/{dataset}/{sub_dataset}/corpus.txt','w') as fout:    
    for iid in tqdm(ref_paper):
        also_viewed = ref_paper[iid]['reference']
        
        # save iid text
        if iid not in node_id_set:
            node_id_set.add(iid)
            fout.write(iid+'\t'+text_process(data[iid]['title'])+'\n')
    
        # save neighbor
        for iid_n in also_viewed:
            if iid_n not in node_id_set:
                node_id_set.add(iid_n)
                fout.write(iid_n+'\t'+text_process(data[iid_n]['title'])+'\n')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47099/47099 [00:00<00:00, 100380.96it/s]


In [15]:
sample_neighbor_num = 5

In [16]:
# generate and save train file

random.seed(0)
sample_neighbor_num = 5

with open(f'data_dir/{dataset}/{sub_dataset}/train.text.jsonl','w') as fout:
    for (q, k) in tqdm(train_pairs):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))
        k_n_pool = set(deepcopy(train_neighbor[k]))

        if k in q_n_pool:
            q_n_pool.remove(k)
        if q in k_n_pool:
            k_n_pool.remove(q)

        q_n_pool = list(q_n_pool)
        k_n_pool = list(k_n_pool)
        random.shuffle(q_n_pool)
        random.shuffle(k_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        if len(k_n_pool) >= sample_neighbor_num:
            k_samples = k_n_pool[:sample_neighbor_num]
        else:
            k_samples = k_n_pool + [-1] * (sample_neighbor_num-len(k_n_pool))
        
        # prepare for writing file
        q_text = text_process(data[q]['title'])
        #q_n_text = '\*\*'.join([text_process(data[q_n]['title']) if q_n != -1 else '' for q_n in q_samples])
        q_n_text = [text_process(data[q_n]['title']) if q_n != -1 else '' for q_n in q_samples]
        
        k_text = text_process(data[k]['title'])
        #k_n_text = '\*\*'.join([text_process(data[k_n]['title']) if k_n != -1 else '' for k_n in k_samples])
        k_n_text = [text_process(data[k_n]['title']) if k_n != -1 else '' for k_n in k_samples]
        
        #q_line = q_text + '\t' + q_n_text
        #k_line = k_text + '\t' + k_n_text
        
        #fout.write(q_line+'\t'+k_line+'\n')
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'k_text':k_text,
            'k_n_text':k_n_text,
        })+'\n')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 191151/191151 [00:12<00:00, 15449.70it/s]


In [17]:
# generate and save val file (make sure to delete items that are not in train set)

random.seed(0)

with open(f'data_dir/{dataset}/{sub_dataset}/val.text.jsonl','w') as fout:
    for (q, k) in tqdm(val_pairs):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))
        k_n_pool = set(deepcopy(train_neighbor[k]))

        if k in q_n_pool:
            q_n_pool.remove(k)
        if q in k_n_pool:
            k_n_pool.remove(q)

        q_n_pool = list(q_n_pool)
        k_n_pool = list(k_n_pool)
        random.shuffle(q_n_pool)
        random.shuffle(k_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        if len(k_n_pool) >= sample_neighbor_num:
            k_samples = k_n_pool[:sample_neighbor_num]
        else:
            k_samples = k_n_pool + [-1] * (sample_neighbor_num-len(k_n_pool))
        
        # prepare for writing file
        q_text = text_process(data[q]['title'])
        #q_n_text = '\*\*'.join([text_process(data[q_n]['title']) if q_n != -1 else '' for q_n in q_samples])
        q_n_text = [text_process(data[q_n]['title']) if q_n != -1 else '' for q_n in q_samples]
        
        k_text = text_process(data[k]['title'])
        #k_n_text = '\*\*'.join([text_process(data[k_n]['title']) if k_n != -1 else '' for k_n in k_samples])
        k_n_text = [text_process(data[k_n]['title']) if k_n != -1 else '' for k_n in k_samples]
        
        #q_line = q_text + '\t' + q_n_text
        #k_line = k_text + '\t' + k_n_text
        
        #fout.write(q_line+'\t'+k_line+'\n')
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'k_text':k_text,
            'k_n_text':k_n_text,
        })+'\n')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20836/20836 [00:01<00:00, 11789.29it/s]


In [18]:
# generate and save test file (make sure to delete items that are not in train set)

random.seed(0)

with open(f'data_dir/{dataset}/{sub_dataset}/test.text.jsonl','w') as fout:
    for (q, k) in tqdm(test_pairs):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))
        k_n_pool = set(deepcopy(train_neighbor[k]))

        if k in q_n_pool:
            q_n_pool.remove(k)
        if q in k_n_pool:
            k_n_pool.remove(q)

        q_n_pool = list(q_n_pool)
        k_n_pool = list(k_n_pool)
        random.shuffle(q_n_pool)
        random.shuffle(k_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        if len(k_n_pool) >= sample_neighbor_num:
            k_samples = k_n_pool[:sample_neighbor_num]
        else:
            k_samples = k_n_pool + [-1] * (sample_neighbor_num-len(k_n_pool))
        
        # prepare for writing file
        q_text = text_process(data[q]['title'])
        #q_n_text = '\*\*'.join([text_process(data[q_n]['title']) if q_n != -1 else '' for q_n in q_samples])
        q_n_text = [text_process(data[q_n]['title']) if q_n != -1 else '' for q_n in q_samples]
        
        k_text = text_process(data[k]['title'])
        #k_n_text = '\*\*'.join([text_process(data[k_n]['title']) if k_n != -1 else '' for k_n in k_samples])
        k_n_text = [text_process(data[k_n]['title']) if k_n != -1 else '' for k_n in k_samples]
        
        #q_line = q_text + '\t' + q_n_text
        #k_line = k_text + '\t' + k_n_text
        
        #fout.write(q_line+'\t'+k_line+'\n')
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'k_text':k_text,
            'k_n_text':k_n_text,
        })+'\n')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 55414/55414 [00:03<00:00, 15958.06it/s]


In [19]:
# save side files
pickle.dump([sample_neighbor_num],open(f'data_dir/{dataset}/{sub_dataset}/neighbor_sampling.pkl','wb'))

In [20]:
# save neighbor file
pickle.dump(train_neighbor,open(f'data_dir/{dataset}/{sub_dataset}/neighbor/train_neighbor.pkl','wb'))
pickle.dump(val_neighbor,open(f'data_dir/{dataset}/{sub_dataset}/neighbor/val_neighbor.pkl','wb'))
pickle.dump(test_neighbor,open(f'data_dir/{dataset}/{sub_dataset}/neighbor/test_neighbor.pkl','wb'))

In [21]:
# save node labels
random.seed(0)

with open(f'data_dir/{dataset}/{sub_dataset}/nc/node_classification.jsonl','w') as fout:
    for q in tqdm(ref_paper):
        
        # prepare sample pool for item
        q_n_pool = set(deepcopy(train_neighbor[q]))

        q_n_pool = list(q_n_pool)
        random.shuffle(q_n_pool)
        
        # sample neighbor
        if len(q_n_pool) >= sample_neighbor_num:
            q_samples = q_n_pool[:sample_neighbor_num]
        else:
            q_samples = q_n_pool + [-1] * (sample_neighbor_num-len(q_n_pool))
        
        # prepare for writing file
        q_text = text_process(data[q]['title'])
        #q_n_text = '\*\*'.join([text_process(data[q_n]['title']) if q_n != -1 else '' for q_n in q_samples])
        q_n_text = [text_process(data[q_n]['title']) if q_n != -1 else '' for q_n in q_samples]
        
        label_names_list = list(set([label_name_dict[lid] for lid in ref_paper[q]['label']]))
        label_ids_list = [label_name2id_dict[lname] for lname in label_names_list]
        
        fout.write(json.dumps({
            'q_text':q_text,
            'q_n_text':q_n_text,
            'labels':label_ids_list,
            'label_names':label_names_list
        })+'\n')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47099/47099 [00:01<00:00, 23675.75it/s]


In [23]:
# generate self constrastive pretraining

corpus_list = []

with open(f'data_dir/{dataset}/{sub_dataset}/corpus.txt') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        tmp = line.strip().split('\t')
        corpus_list.append(tmp[1])
        
with open(f'data_dir/{dataset}/{sub_dataset}/self-train/train.text.jsonl','w') as fout:
    for dd in tqdm(corpus_list):
        fout.write(json.dumps({
            'q_text':dd,
            'q_n_text':[''],
            'k_text':dd,
            'k_n_text':[''],
        })+'\n')

with open(f'data_dir/{dataset}/{sub_dataset}/self-train/val.text.jsonl','w') as fout:
    for dd in tqdm(corpus_list[:int(0.2*len(corpus_list))]):
        fout.write(json.dumps({
            'q_text':dd,
            'q_n_text':[''],
            'k_text':dd,
            'k_n_text':[''],
        })+'\n')
        
with open(f'data_dir/{dataset}/{sub_dataset}/self-train/test.text.jsonl','w') as fout:
    for dd in tqdm(corpus_list[int(0.8*len(corpus_list)):]):
        fout.write(json.dumps({
            'q_text':dd,
            'q_n_text':[''],
            'k_text':dd,
            'k_n_text':[''],
        })+'\n')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 58186/58186 [00:00<00:00, 1124803.65it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 58186/58186 [00:00<00:00, 84601.59it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11637/11637 [00:00<00:00, 134216.71it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11638/11638 [00:00<00:00, 132842.68it/s]


## Generate node classification data for retrieval and reranking

In [24]:
# write labels into documents.json

labels_dict = []
#for lid in label_name_dict:
for lname in label_name2id_dict:
    if lname != 'null':
        labels_dict.append({'id':label_name2id_dict[lname], 'contents':lname})
json.dump(labels_dict, open(f'data_dir/MAG/{sub_dataset}/nc/documents.json', 'w'), indent=4)

with open(f'data_dir/MAG/{sub_dataset}/nc/documents.txt', 'w') as fout:
    #for lid in label_name_dict:
    for lname in label_name2id_dict:
        if lname == 'null':
            continue
        fout.write(label_name2id_dict[lname]+'\t'+lname+'\n')

In [25]:
# generate node query file & ground truth file

docid = 0

with open(f'data_dir/MAG/{sub_dataset}/nc/node_classification.jsonl') as f, open(f'data_dir/MAG/{sub_dataset}/nc/node_text.tsv', 'w') as fout1, open(f'data_dir/MAG/{sub_dataset}/nc/truth.trec', 'w') as fout2:
    readin = f.readlines()
    for line in tqdm(readin):
        tmp = json.loads(line)
        fout1.write(str(docid) + '\t' + tmp['q_text'] + '\n')
        for label in tmp['labels']:
            fout2.write(str(docid)+' '+str(0)+' '+label+' '+str(1)+'\n')
        docid += 1

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47099/47099 [00:00<00:00, 86122.53it/s]


In [26]:
# generate node query file & ground truth file

docid = 0

with open(f'data_dir/MAG/{sub_dataset}/nc/node_classification.jsonl') as f, open(f'data_dir/MAG/{sub_dataset}/nc/train.text.jsonl', 'w') as fout1, open(f'data_dir/MAG/{sub_dataset}/nc/val.text.jsonl', 'w') as fout2, open(f'data_dir/MAG/{sub_dataset}/nc/test.truth.trec', 'w') as fout3, open(f'data_dir/MAG/{sub_dataset}/nc/test.node.text.jsonl', 'w') as fout4:
    readin = f.readlines()
    total_len = len(readin)
    for line in tqdm(readin[:int(0.8*total_len)]):
        tmp = json.loads(line)
        for label_name in tmp['label_names']:
            fout1.write(json.dumps({
                'q_text':tmp['q_text'],
                'q_n_text':tmp['q_n_text'],
                'k_text':label_name,
                'k_n_text':[''],
            })+'\n')
        docid += 1
    
    for line in tqdm(readin[int(0.8*total_len):int(0.9*total_len)]):
        tmp = json.loads(line)
        for label_name in tmp['label_names']:
            fout2.write(json.dumps({
                'q_text':tmp['q_text'],
                'q_n_text':tmp['q_n_text'],
                'k_text':label_name,
                'k_n_text':[''],
            })+'\n')
        docid += 1
        
    for line in tqdm(readin[int(0.9*total_len):]):
        tmp = json.loads(line)
        #fout4.write(str(docid) + '\t' + tmp['q_text'] + '\n')
        fout4.write(json.dumps({
                'id': str(docid),
                'text':tmp['q_text'],
                'n_text':tmp['q_n_text']
            })+'\n')
        for label in tmp['labels']:
            fout3.write(str(docid)+' '+str(0)+' '+label+' '+str(1)+'\n')
        docid += 1

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 37679/37679 [00:01<00:00, 27108.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4710/4710 [00:00<00:00, 33668.52it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4710/4710 [00:00<00:00, 48783.26it/s]


## Generate Coarse-grained Classification Data

In [4]:
# read label name dict
coarse_label_id2name = {}
#coarse_label_id2idx = {}

with open(f'data_dir/MAG/{sub_dataset}/labels.txt') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        tmp = line.strip().split('\t')
        if tmp[2] == '1':
            coarse_label_id2name[tmp[0]] = tmp[1]
            #coarse_label_id2idx[tmp[0]] = len(coarse_label_id2idx)

print(f'Num of unique labels:{len(coarse_label_id2name)};{coarse_label_id2name}')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5205/5205 [00:00<00:00, 1049735.65it/s]

Num of unique labels:40;{'40700': 'industrial organization', '50522688': 'economic growth', '187736073': 'management', '138921699': 'political economy', '134560507': 'environmental economics', '26271046': 'economic geography', '21547014': 'operations management', '45355965': 'socioeconomics', '106159729': 'financial economics', '47768531': 'development economics', '121955636': 'accounting', '18547055': 'international economics', '149782125': 'econometrics', '34447519': 'market economy', '145236788': 'labour economics', '556758197': 'monetary economics', '105639569': 'economic policy', '48824518': 'agricultural economics', '139719470': 'macroeconomics', '539667460': 'management science', '167562979': 'classical economics', '133425853': 'neoclassical economics', '175444787': 'microeconomics', '107826830': 'environmental resource management', '10138342': 'finance', '6303427': 'economic history', '4249254': 'demographic economics', '73283319': 'financial system', '144237770': 'mathematical




### Take care here, you need to generate data for 8 & 16 respectively.

In [5]:
# generate train/val/test file
# filter out and only use node which has single label

ktrain = 8 # train sample threshold, how many training samples do we have for each class
kdev = 8 # dev sample threshold, how many dev samples do we have for each class
label_samples = defaultdict(list)

with open(f'data_dir/MAG/{sub_dataset}/nc/node_classification.jsonl') as f:
    readin = f.readlines()
    for line in tqdm(readin):
        tmp = json.loads(line)
        inter_label = list(set(tmp['labels']) & set(coarse_label_id2name))
        if len(inter_label) == 1:
            label_samples[inter_label[0]].append(tmp)
            
# select labels
coarse_label_id2idx = {}
for l in label_samples:
    if len(label_samples[l]) > ktrain + kdev:
        coarse_label_id2idx[l] = len(coarse_label_id2idx)
        
print(f'Num of unique labels:{len(coarse_label_id2idx)};{coarse_label_id2idx}')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 145761/145761 [00:02<00:00, 52456.28it/s]

Num of unique labels:40;{'144237770': 0, '145236788': 1, '10138342': 2, '149782125': 3, '139719470': 4, '175444787': 5, '50522688': 6, '106159729': 7, '100001284': 8, '190253527': 9, '136264566': 10, '556758197': 11, '21547014': 12, '162118730': 13, '40700': 14, '138921699': 15, '54750564': 16, '45355965': 17, '73283319': 18, '121955636': 19, '47768531': 20, '155202549': 21, '165556158': 22, '118084267': 23, '48824518': 24, '18547055': 25, '4249254': 26, '133425853': 27, '175605778': 28, '26271046': 29, '105639569': 30, '34447519': 31, '134560507': 32, '167562979': 33, '539667460': 34, '187736073': 35, '549774020': 36, '74363100': 37, '107826830': 38, '6303427': 39}





In [34]:
# save

if not os.path.exists(f'data_dir/MAG/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}'):
    os.mkdir(f'data_dir/MAG/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}')

with open(f'data_dir/MAG/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}/train.text.jsonl', 'w') as fout1, open(f'data_dir/MAG/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}/val.text.jsonl', 'w') as fout2, open(f'data_dir/MAG/{sub_dataset}/nc-coarse/{str(ktrain)}_{str(kdev)}/test.text.jsonl', 'w') as fout3:
    
    assert ktrain+kdev <= 32
    
    for l in coarse_label_id2idx:
        train_data = label_samples[l][:ktrain]
        dev_data = label_samples[l][ktrain:(ktrain+kdev)]
        #test_data = label_samples[l][(ktrain+kdev):]
        test_data = label_samples[l][32:]
    
        # write train
        for d in train_data:
            fout1.write(json.dumps({
                'q_text':d['q_text'],
                'q_n_text':d['q_n_text'],
                'label':coarse_label_id2idx[l]
            })+'\n')
    
        # write dev
        for d in dev_data:
            fout2.write(json.dumps({
                'q_text':d['q_text'],
                'q_n_text':d['q_n_text'],
                'label':coarse_label_id2idx[l]
            })+'\n')
    
        # write test
        for d in test_data:
            fout3.write(json.dumps({
                'q_text':d['q_text'],
                'q_n_text':d['q_n_text'],
                'label':coarse_label_id2idx[l]
            })+'\n')

pickle.dump(coarse_label_id2idx, open(f'data_dir/MAG/{sub_dataset}/nc-coarse/coarse_label_id2idx.pkl', 'wb'))
pickle.dump([ktrain, kdev], open(f'data_dir/MAG/{sub_dataset}/nc-coarse/threshold.pkl', 'wb'))