In [20]:
import json
import os
import random
import re
from collections import defaultdict

import numpy as np
from sklearn.model_selection import train_test_split

In [21]:
# Name of the sampling configuration; also used as the output directory name.
state = '50%*1'


def _read_split(split):
    """Load one dataset split as a list of ``{'text', 'labels'}`` records.

    Reads ``origin/text_<split>`` and ``origin/label_<split>`` line by line.
    Line *i* of the label file holds the whitespace-separated labels for
    line *i* of the text file.

    Raises:
        ValueError: if the two files do not have the same number of lines.
            A plain ``zip`` would silently drop the surplus lines and hide a
            misaligned dataset.
    """
    with open(f'origin/text_{split}', encoding='utf-8') as f:
        texts = f.readlines()
    with open(f'origin/label_{split}', encoding='utf-8') as f:
        label_rows = f.readlines()

    if len(texts) != len(label_rows):
        raise ValueError(
            f'origin/text_{split} has {len(texts)} lines but '
            f'origin/label_{split} has {len(label_rows)}'
        )

    return [
        {'text': text.strip(), 'labels': labels.strip().split()}
        for text, labels in zip(texts, label_rows)
    ]


train = _read_split('train')

# The validation records are appended to the test set (test rows first, then
# val rows); no separate validation list is kept.
# NOTE(review): presumably intentional, since nothing else consumes a val set - confirm.
test = _read_split('test') + _read_split('val')

In [22]:
# Clean the text of every train and test record in place: newline characters
# become a single space and backslashes are removed.
# NOTE(review): the texts come from readlines() + strip(), so they probably
# never contain a newline character. If the goal was to drop the two-character
# sequence backslash + "n", this does not do that - confirm.
for records in (train, test):
    for record in records:
        without_newlines = record['text'].replace('\n', ' ')
        record['text'] = without_newlines.replace('\\', '')

In [23]:
# Tally, per label, how many train records and how many test records carry it.
train_class_count = defaultdict(int)
test_class_count = defaultdict(int)

for records, counter in ((train, train_class_count), (test, test_class_count)):
    for record in records:
        for name in record['labels']:
            counter[name] += 1

# Labels that occur in the train records but in no test record.
unseen_class_count = train_class_count.keys() - test_class_count.keys()

In [24]:
# Report dataset sizes, label-set sizes and the per-label frequencies
# (most frequent label first; ties keep their insertion order).
train_ranking = sorted(train_class_count.items(), key=lambda pair: pair[1], reverse=True)
test_ranking = sorted(test_class_count.items(), key=lambda pair: pair[1], reverse=True)

report = (
    ('Number of train data:', len(train)),
    ('Number of test data:', len(test)),
    ('Number of classes in train data:', len(train_class_count)),
    ('Number of classes in test data:', len(test_class_count)),
    ('Number of unseen classes:', len(unseen_class_count)),
    ('Count of train label:\n', train_ranking),
    ('Count of test label:\n', test_ranking),
)
for caption, value in report:
    print(caption, value)

Number of train data: 53840
Number of test data: 2000
Number of classes in train data: 54
Number of classes in test data: 54
Number of unseen classes: 0
Count of train label:
 [('cs.IT', 17152), ('math.IT', 17152), ('cs.LG', 8007), ('cs.AI', 4991), ('stat.ML', 4939), ('cs.DS', 4673), ('cs.DM', 4074), ('cs.SI', 4003), ('cs.LO', 3517), ('math.CO', 3355), ('physics.soc-ph', 3329), ('cs.NI', 3081), ('cs.CC', 3079), ('math.OC', 3043), ('cs.CL', 2831), ('cs.CV', 2774), ('cs.CR', 2723), ('cs.DC', 2259), ('cs.SY', 2242), ('cs.NE', 2024), ('cs.IR', 2018), ('cs.GT', 1720), ('quant-ph', 1581), ('cs.CY', 1572), ('cs.PL', 1382), ('cs.DB', 1267), ('cs.SE', 1261), ('math.PR', 1194), ('cs.CG', 1080), ('cs.NA', 1052), ('cs.HC', 1005), ('cs.MA', 956), ('math.NA', 945), ('cs.CE', 923), ('cs.RO', 921), ('cs.FL', 915), ('math.ST', 875), ('stat.TH', 875), ('cs.DL', 867), ('cmp-lg', 856), ('cs.MM', 711), ('cs.PF', 687), ('math.LO', 667), ('cond-mat.stat-mech', 664), ('stat.AP', 628), ('stat.ME', 503), ('cs.M

In [25]:
# Total number of characters across all texts of each set ...
train_text_length = sum(len(record['text']) for record in train)
test_text_length = sum(len(record['text']) for record in test)

# ... divided by the number of records gives the mean text length.
print('Avg length of train text:', train_text_length / len(train))
print('Avg length of test text:', test_text_length / len(test))

Avg length of train text: 984.6816121842496
Avg length of test text: 981.3435


In [26]:
# Assign consecutive integer ids to labels in order of first appearance,
# scanning all train records first and then the test records.
label2id = dict()

for records in (train, test):
    for record in records:
        for name in record['labels']:
            if name in label2id:
                continue
            label2id[name] = len(label2id)

# Number of ids handed out so far.
cnt = len(label2id)

# Reverse mapping: id -> label name.
id2label = dict(zip(label2id.values(), label2id.keys()))

In [27]:
# Reorder each train record's labels in place so that labels[0] is the label
# with the lowest count in train_class_count (the first one wins on ties).
# The relative order of the other labels is unchanged.
#
# min() is used instead of a hard-coded "larger than any count" sentinel, so
# this keeps working when a label occurs more often than the sentinel value.
for line in train:
    labels = line['labels']
    if not labels:
        # A record without labels has nothing to reorder.
        continue
    rarest = min(labels, key=lambda name: train_class_count[name])
    if labels[0] != rarest:
        labels.remove(rarest)
        labels.insert(0, rarest)

In [28]:
# Group train records into one bucket per label, keyed by the record's first
# label (label_collection[i] holds the records whose labels[0] has id i).
label_collection = [[] for _ in range(len(train_class_count))]
for line in train:
    label_collection[label2id[line['labels'][0]]].append(line)

# `state` has the form '<P>%*<K>':
#   - threshold = the P-th percentile of the bucket sizes (truncated to int),
#   - buckets smaller than the threshold are dropped,
#   - every remaining bucket is cut to at most threshold * K records.
# '50%*1' and '25%*3' behave exactly as before; other well-formed values are
# now honoured, and a malformed value fails loudly instead of silently
# skipping the filtering step.
state_match = re.fullmatch(r'(\d+)%\*(\d+)', state)
if state_match is None:
    raise ValueError(f"state must look like '<percentile>%*<multiplier>', got {state!r}")
percentile, multiplier = (int(group) for group in state_match.groups())

threshold = int(np.percentile([len(t) for t in label_collection], percentile))
if threshold < 1:
    # A zero threshold would keep every bucket but truncate it to zero
    # records, leaving only empty buckets behind.
    raise ValueError(f'percentile {percentile} of the bucket sizes is 0; nothing would be kept')

label_collection = [t[:threshold * multiplier] for t in label_collection if len(t) >= threshold]

In [29]:
# Rebuild the label <-> id mappings so that ids follow the position of the
# surviving buckets; each bucket is named after the first label of its first
# record.
label2id = {
    bucket[0]['labels'][0]: position
    for position, bucket in enumerate(label_collection)
}

id2label = dict(zip(label2id.values(), label2id.keys()))

In [30]:
# Keep only the test records whose labels are ALL present in label2id; a
# record with even one unknown label is dropped.
# The list is filtered in a single pass (slice assignment keeps the same list
# object) instead of calling pop(idx) repeatedly, which shifts the tail of the
# list on every removal.
test[:] = [
    line for line in test
    if all(label in label2id for label in line['labels'])
]

In [31]:
# Persist both label mappings under the directory named after `state`.
# The directory is created on demand so that a fresh checkout (or a new
# `state` value) does not fail with FileNotFoundError.
os.makedirs(state, exist_ok=True)

# ensure_ascii=False writes non-ASCII characters verbatim, so the file
# encoding is pinned to UTF-8 rather than left to the platform default.
with open(f'{state}/id2label.json', 'w', encoding='utf-8') as f:
    # Note: JSON object keys are strings, so the integer ids become strings on disk.
    json.dump(id2label, f, ensure_ascii=False, indent=2)

with open(f'{state}/label2id.json', 'w', encoding='utf-8') as f:
    json.dump(label2id, f, ensure_ascii=False, indent=2)

In [32]:
# Build the final train list by interleaving the per-label buckets round-robin:
# first record of every bucket, then the second record of every bucket, and
# so on, skipping buckets that are already exhausted.
# Despite the name, the resulting order is deterministic, not random.
#
# This yields the same sequence as walking the buckets cyclically with one
# cursor per bucket, without re-summing all bucket sizes on every step.
longest_bucket = max((len(bucket) for bucket in label_collection), default=0)

shuffle_train = [
    bucket[round_no]
    for round_no in range(longest_bucket)
    for bucket in label_collection
    if round_no < len(bucket)
]

In [33]:
# Cap the test set at three times the number of train records, keeping the
# leading records.
max_test_size = 3 * len(shuffle_train)
test = test[:max_test_size]

In [34]:
# Show how many records remain in each set after balancing and filtering.
for caption, records in (
    ('Number of train data:', shuffle_train),
    ('Number of test data:', test),
):
    print(caption, len(records))

Number of train data: 20385
Number of test data: 312


In [35]:
# Strip from every train record the labels that no longer have an id in
# label2id. The label list is updated in place and the order of the surviving
# labels is preserved.
for record in shuffle_train:
    record['labels'][:] = [name for name in record['labels'] if name in label2id]

In [36]:
# Recompute the label statistics, this time over the interleaved train list
# (shuffle_train) and the filtered test list.
train_class_count = defaultdict(int)
test_class_count = defaultdict(int)

for records, counter in ((shuffle_train, train_class_count), (test, test_class_count)):
    for record in records:
        for name in record['labels']:
            counter[name] += 1

# Labels that occur in the train records but in no test record.
unseen_class_count = train_class_count.keys() - test_class_count.keys()

# Per-label frequencies, most frequent first (ties keep insertion order).
train_ranking = sorted(train_class_count.items(), key=lambda pair: pair[1], reverse=True)
test_ranking = sorted(test_class_count.items(), key=lambda pair: pair[1], reverse=True)

report = (
    ('Number of train data:', len(shuffle_train)),
    ('Number of test data:', len(test)),
    ('Number of classes in train data:', len(train_class_count)),
    ('Number of classes in test data:', len(test_class_count)),
    ('Number of unseen classes:', len(unseen_class_count)),
    ('Count of train label:\n', train_ranking),
    ('Count of test label:\n', test_ranking),
)
for caption, value in report:
    print(caption, value)

# Mean text length in characters for each set.
train_text_length = sum(len(record['text']) for record in shuffle_train)
test_text_length = sum(len(record['text']) for record in test)

print('Avg length of train text:', train_text_length / len(shuffle_train))
print('Avg length of test text:', test_text_length / len(test))

Number of train data: 20385
Number of test data: 312
Number of classes in train data: 27
Number of classes in test data: 26
Number of unseen classes: 1
Count of train label:
 [('cs.IT', 3110), ('cs.CL', 2190), ('math.OC', 1793), ('stat.ML', 1768), ('cs.NI', 1704), ('cs.CC', 1564), ('math.CO', 1542), ('cs.CR', 1520), ('cs.CV', 1388), ('cs.DC', 1315), ('physics.soc-ph', 1259), ('cs.SY', 1252), ('cs.CY', 1153), ('cs.GT', 1087), ('cs.PL', 1084), ('cs.IR', 1051), ('cs.SE', 918), ('cs.NE', 907), ('cs.DB', 845), ('quant-ph', 832), ('cs.MA', 825), ('cs.HC', 819), ('cs.CG', 796), ('math.PR', 785), ('cs.RO', 759), ('cs.FL', 755), ('cmp-lg', 755)]
Count of test label:
 [('cs.CL', 53), ('cs.CR', 50), ('cs.SY', 45), ('math.OC', 42), ('cs.NI', 35), ('cs.CC', 34), ('cmp-lg', 33), ('cs.CY', 33), ('cs.CV', 31), ('cs.IR', 28), ('cs.DC', 25), ('quant-ph', 25), ('cs.PL', 25), ('cs.NE', 25), ('cs.SE', 24), ('cs.GT', 21), ('cs.DB', 21), ('cs.FL', 20), ('cs.HC', 20), ('cs.MA', 17), ('cs.RO', 17), ('math.CO',

In [37]:
# Load the label descriptions: one JSON object per line, each with a 'label'
# and a 'description' field. Blank lines are ignored.
with open('label-description.jsonl', encoding='utf-8') as f:
    label_desc = {}
    for raw in f:
        if not raw.strip():
            continue
        entry = json.loads(raw)
        label_desc[entry['label']] = entry['description']

# For every train and test record:
#   - replace the label names by their integer ids (in place),
#   - attach the description of the record's first label.
# Labels that are already integer ids are mapped back to their names first,
# so running this cell a second time does not fail with a KeyError.
for records in (shuffle_train, test):
    for line in records:
        label_names = [
            id2label[label] if isinstance(label, int) else label
            for label in line['labels']
        ]
        line['labels'][:] = [label2id[name] for name in label_names]
        line['label_description'] = label_desc[label_names[0]]

# ensure_ascii=False emits non-ASCII characters verbatim, so the output
# encoding is pinned to UTF-8 instead of the platform default.
with open(f'{state}/train.json', 'w', encoding='utf-8') as f:
    json.dump(shuffle_train, f, ensure_ascii=False, indent=2)

with open(f'{state}/test.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, ensure_ascii=False, indent=2)

In [38]:
# Load the alternative label descriptions (the "with example" variant): one
# JSON object per line with a 'label' and a 'description' field. Blank lines
# are ignored.
with open('label-description-with-example.jsonl', encoding='utf-8') as f:
    label_desc = {}
    for raw in f:
        if not raw.strip():
            continue
        entry = json.loads(raw)
        label_desc[entry['label']] = entry['description']

# Overwrite each record's description with the variant for its first label.
# The labels are looked up through id2label, i.e. they are expected to be
# integer ids at this point.
for records in (shuffle_train, test):
    for line in records:
        label_text = id2label[line['labels'][0]]
        line['label_description'] = label_desc[label_text]

# ensure_ascii=False emits non-ASCII characters verbatim, so the output
# encoding is pinned to UTF-8 instead of the platform default.
with open(f'{state}/train-with-example.json', 'w', encoding='utf-8') as f:
    json.dump(shuffle_train, f, ensure_ascii=False, indent=2)

with open(f'{state}/test-with-example.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, ensure_ascii=False, indent=2)