In [None]:
import json
from collections import defaultdict
import re
import numpy as np
import random

In [None]:
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    The file is decoded as UTF-8 explicitly instead of with the platform's
    default encoding, and blank lines (for example a trailing empty line)
    are skipped rather than handed to ``json.loads``.
    """
    with open(path, encoding='utf-8') as f:
        return [json.loads(row) for row in f if row.strip()]


train = load_jsonl('origin/train.jsonl')
test = load_jsonl('origin/test.jsonl')

In [None]:
# Flatten every text to a single line: newlines become spaces and
# backslashes are removed. Records of both splits are edited in place.
for dataset in (train, test):
    for record in dataset:
        record['text'] = record['text'].replace('\n', ' ').replace('\\', '')

In [None]:
# count classes
# Tally how many records carry each label name, separately per split.
train_class_count, test_class_count = defaultdict(int), defaultdict(int)
for tally, dataset in ((train_class_count, train), (test_class_count, test)):
    for record in dataset:
        tally[record['label_text']] += 1

# Label names present in train but absent from test. This is a set of names,
# despite the "_count" suffix.
# NOTE(review): "unseen" may have been meant as test-minus-train (classes the
# model never sees during training) - confirm the intended direction.
unseen_class_count = set(train_class_count) - set(test_class_count)

In [None]:
def _most_frequent_first(counts):
    """Return the (label, count) pairs of ``counts`` by descending count."""
    return sorted(counts.items(), key=lambda pair: -pair[1])


# Sizes of the raw splits.
print('Number of train data:', len(train))
print('Number of test data:', len(test))

# How many distinct label names each split has, and how many are in the
# train-minus-test set computed above.
print('Number of classes in train data:', len(train_class_count))
print('Number of classes in test data:', len(test_class_count))
print('Number of unseen classes:', len(unseen_class_count))

# Per-label frequencies, most frequent first.
print('Count of train label:\n', _most_frequent_first(train_class_count))
print('Count of test label:\n', _most_frequent_first(test_class_count))

In [None]:
# label2id
# Map each label name to the numeric id stored next to it. The first record
# seen for a name decides its id; train is scanned before test.
label2id = {}
for dataset in (train, test):
    for record in dataset:
        name = record['label_text']
        if name in label2id:
            continue
        label2id[name] = record['label']

# Inverse mapping. If two names share an id, the name inserted last wins.
id2label = dict((label_id, name) for name, label_id in label2id.items())

In [None]:
# delete long tail data
# Percentile of the per-class sizes used as the cut-off below.
LONG_TAIL_PERCENTILE = 25

# Group the training records by label name. Grouping through a dict (instead
# of a list pre-sized from the class count and indexed by label id) means the
# label ids do not have to be exactly 0..n_classes-1, and there can never be
# an empty bucket dragging the threshold down to 0. This cell reads the ids
# from the records themselves and does not depend on label2id.
records_by_label = {}
for line in train:
    records_by_label.setdefault(line['label_text'], []).append(line)

# Order the buckets by ascending label id, taken from each bucket's first
# record, so the bucket order matches indexing by label id.
label_collection = sorted(records_by_label.values(), key=lambda bucket: bucket[0]['label'])

# Classes smaller than the percentile threshold are dropped; the remaining
# classes are truncated to the threshold, so every surviving class ends up
# with exactly `threshold` records.
threshold = int(np.percentile([len(t) for t in label_collection], LONG_TAIL_PERCENTILE))
label_collection = [t[:threshold] for t in label_collection if len(t) >= threshold]

In [None]:
# refresh label2id
# Renumber the surviving classes 0..n-1 in label_collection order. The class
# name is read from the first record of each bucket.
label2id = {bucket[0]['label_text']: new_id for new_id, bucket in enumerate(label_collection)}

# Inverse lookup from the new id back to the class name.
id2label = dict((new_id, name) for name, new_id in label2id.items())

In [None]:
# delete unseen test
# Keep only the test records whose label name is still in label2id.
# Filtering in one pass avoids the quadratic cost of popping from the middle
# of the list; slice assignment keeps `test` the same list object, so the
# removal is still in place.
test[:] = [line for line in test if line['label_text'] in label2id]

In [None]:
# Persist both label mappings. ensure_ascii=False writes non-ASCII characters
# verbatim, so the files are opened as UTF-8 explicitly instead of relying on
# the platform's default encoding.
# Note: JSON object keys are always strings, so the integer keys of id2label
# are written as strings ("0", "1", ...).
with open('id2label.json', 'w', encoding='utf-8') as f:
    json.dump(id2label, f, ensure_ascii=False, indent=2)

with open('label2id.json', 'w', encoding='utf-8') as f:
    json.dump(label2id, f, ensure_ascii=False, indent=2)

In [None]:
# shuffle train dataset
# Despite the name, this is a deterministic round-robin interleave, not a
# random shuffle: round k takes the k-th record of every class that still has
# one, visiting the classes in label_collection order. Consecutive training
# records therefore come from different classes.
longest = max((len(bucket) for bucket in label_collection), default=0)

shuffle_train = []
for position in range(longest):
    for bucket in label_collection:
        # Skip classes that have already run out of records.
        if position < len(bucket):
            shuffle_train.append(bucket[position])

In [None]:
# Report the split sizes after filtering and interleaving.
for caption, split in (('Number of train data:', shuffle_train), ('Number of test data:', test)):
    print(caption, len(split))

In [None]:
# add label description
# Load the label-name -> description table. The file is read as UTF-8
# explicitly and blank lines are skipped.
label_desc = {}
with open('label-description.jsonl', encoding='utf-8') as f:
    for row in f:
        if row.strip():
            entry = json.loads(row)
            label_desc[entry['label']] = entry['description']

# Replace the label name of every record with its refreshed numeric id and
# attach the description. The records are modified in place.
for dataset in (shuffle_train, test):
    for line in dataset:
        if 'label_text' in line:
            # Resolve both lookups before touching the record, so a missing
            # id or description does not leave it half-converted.
            label_text = line['label_text']
            new_id = label2id[label_text]
            description = label_desc[label_text]
            del line['label_text']
            line['label'] = new_id
        else:
            # Record was already converted by an earlier run of this cell:
            # 'label' holds the refreshed id, so recover the name from it.
            label_text = id2label[line['label']]
            description = label_desc[label_text]
        line['label_description'] = description

# ensure_ascii=False writes non-ASCII text verbatim, hence explicit UTF-8.
with open('train.json', 'w', encoding='utf-8') as f:
    json.dump(shuffle_train, f, ensure_ascii=False, indent=2)

with open('test.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, ensure_ascii=False, indent=2)

In [None]:
# add label description (from label-description-with-example.jsonl)
# Load the alternative label-name -> description table. The file is read as
# UTF-8 explicitly and blank lines are skipped.
label_desc = {}
with open('label-description-with-example.jsonl', encoding='utf-8') as f:
    for row in f:
        if row.strip():
            entry = json.loads(row)
            label_desc[entry['label']] = entry['description']

# Overwrite 'label_description' on every record. This expects 'label' to
# already hold the refreshed numeric id (as set by the previous cell); the
# label name is recovered through id2label.
for dataset in (shuffle_train, test):
    for line in dataset:
        label_text = id2label[line['label']]
        line['label_description'] = label_desc[label_text]

# ensure_ascii=False writes non-ASCII text verbatim, hence explicit UTF-8.
with open('train-with-example.json', 'w', encoding='utf-8') as f:
    json.dump(shuffle_train, f, ensure_ascii=False, indent=2)

with open('test-with-example.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, ensure_ascii=False, indent=2)