In [1]:
import json
from collections import defaultdict
import re
import numpy as np
import random
from sklearn.model_selection import train_test_split

In [2]:
state = '50%*1'

file_names = ["reut2-" + "%03d" % i + ".sgm" for i in range(22)]
data = []
for file_name in file_names:
    try:
        with open(f'origin/{file_name}') as f:
            file = f.readlines()
    except:
        with open(f'origin/{file_name}', 'rb') as f:
            file = f.readlines()
            file = list(map(lambda x: x.decode('utf-8', 'ignore'), file))
            
    file = ''.join(file)
    file = re.findall('<REUTERS.*?</REUTERS>', file, re.DOTALL)
    for line in file:
        topics_raw = re.findall('<TOPICS>(<D>.*</D>)</TOPICS>', line)
        if topics_raw:
            topics = re.findall('<D>(.*?)</D>', topics_raw[0])
            text = re.findall('<BODY>(.*?)</BODY>', line.replace('\n', ''))
            if text:
                text = ' '.join(text[0].split())
                text = re.sub(r"&#\d*;", '', text)
                data.append({
                    'text': text,
                    'labels': topics
                })

In [3]:
train, test = train_test_split(data, test_size=0.2, random_state=42)

In [4]:
# delete \n and \
for line in train:
    line['text'] = line['text'].replace('\n', ' ')
    line['text'] = line['text'].replace('\\', '')

for line in test:
    line['text'] = line['text'].replace('\n', ' ')
    line['text'] = line['text'].replace('\\', '')

In [5]:
# count classes
train_class_count = defaultdict(int)
test_class_count = defaultdict(int)

for line in train:
    for label in line['labels']:
        train_class_count[label] += 1

for line in test:
    for label in line['labels']:
        test_class_count[label] += 1

unseen_class_count = train_class_count.keys() - test_class_count.keys()

In [6]:
print('Number of train data:', len(train))
print('Number of test data:', len(test))

print('Number of classes in train data:', len(train_class_count))
print('Number of classes in test data:', len(test_class_count))
print('Number of unseen classes:', len(unseen_class_count))

print('Count of train label:\n', sorted(train_class_count.items(), key=lambda x: -x[1]))
print('Count of test label:\n', sorted(test_class_count.items(), key=lambda x: -x[1]))

Number of train data: 8301
Number of test data: 2076
Number of classes in train data: 116
Number of classes in test data: 84
Number of unseen classes: 35
Count of train label:
 [('earn', 3004), ('acq', 1773), ('money-fx', 532), ('grain', 468), ('crude', 451), ('trade', 385), ('interest', 340), ('ship', 246), ('wheat', 235), ('corn', 180), ('sugar', 150), ('oilseed', 144), ('dlr', 133), ('gnp', 122), ('coffee', 115), ('veg-oil', 108), ('gold', 106), ('nat-gas', 101), ('money-supply', 100), ('livestock', 92), ('soybean', 89), ('bop', 83), ('cpi', 79), ('reserves', 64), ('carcass', 64), ('copper', 61), ('jobs', 57), ('cocoa', 57), ('rice', 55), ('iron-steel', 52), ('cotton', 51), ('yen', 50), ('ipi', 47), ('alum', 45), ('meal-feed', 43), ('barley', 40), ('gas', 40), ('rubber', 38), ('palm-oil', 35), ('pet-chem', 33), ('zinc', 32), ('silver', 31), ('sorghum', 31), ('strategic-metal', 27), ('tin', 26), ('lead', 25), ('rapeseed', 25), ('wpi', 24), ('fuel', 22), ('soy-oil', 21), ('soy-meal', 

In [7]:
train_text_length, test_text_length = 0, 0
for line in train:
    train_text_length += len(line['text'])
for line in test:
    test_text_length += len(line['text'])

print('Avg length of train text:', train_text_length / len(train))
print('Avg length of test text:', test_text_length / len(test))

Avg length of train text: 798.7025659559089
Avg length of test text: 790.3684971098265


In [8]:
# label2id
label2id = dict()
cnt = 0

for line in train:
    for label in line['labels']:
        if label not in label2id:
            label2id[label] = cnt
            cnt += 1

for line in test:
    for label in line['labels']:
        if label not in label2id:
            label2id[label] = cnt
            cnt += 1

id2label = {v: k for k, v in label2id.items()}

In [9]:
# to one label, labels[0] is the rarest one
for line in train:
    minn, minn_label = 20000, ''
    for label in line['labels']:
        if train_class_count[label] < minn:
            minn = train_class_count[label]
            minn_label = label
    if line['labels'][0] != minn_label:
        line['labels'].remove(minn_label)
        line['labels'].insert(0, minn_label)

In [10]:
# delete long tail data
label_collection = [[] for _ in range(len(train_class_count))]
for line in train:
    idx = label2id[line['labels'][0]]
    label_collection[idx].append(line)

if state == '50%*1':
    threshold = int(np.percentile([len(t) for t in label_collection], 50))
    label_collection = [t[:threshold * 1] for t in label_collection if len(t) >= threshold]
elif state == '25%*3':
    threshold = int(np.percentile([len(t) for t in label_collection], 25))
    label_collection = [t[:threshold * 3] for t in label_collection if len(t) >= threshold]

In [11]:
# refresh label2id
label2id = dict()
for idx, piece in enumerate(label_collection):
    label2id[piece[0]['labels'][0]] = idx

id2label = {v: k for k, v in label2id.items()}

In [12]:
# delete unseen test
idx = 0
while idx < len(test):
    for label in test[idx]['labels']:
        if label not in label2id:
            test.pop(idx)
            break
    else:
        idx += 1

In [None]:
with open(f'{state}/id2label.json', 'w') as f:
    json.dump(id2label, f, ensure_ascii=False, indent=2)

with open(f'{state}/label2id.json', 'w') as f:
    json.dump(label2id, f, ensure_ascii=False, indent=2)

In [13]:
# shuffle train dataset
collection_pointer = [0] * len(label_collection)

shuffle_train = []
global_idx, idx = 0, 0
while global_idx < sum(len(t) for t in label_collection):
    if collection_pointer[idx] < len(label_collection[idx]):
        shuffle_train.append(label_collection[idx][collection_pointer[idx]])
        collection_pointer[idx] += 1
        idx = (idx + 1) % len(label_collection)
        global_idx += 1
    else:
        idx = (idx + 1) % len(label_collection) 

In [14]:
# subset of test
test = test[:len(shuffle_train) * 3]

In [15]:
print('Number of train data:', len(shuffle_train))
print('Number of test data:', len(test))

Number of train data: 671
Number of test data: 2013


In [16]:
# delete unseen train labels
for line in shuffle_train:
    idx = 0
    while idx < len(line['labels']):
        if line['labels'][idx] not in label2id:
            line['labels'].pop(idx)
        else:
            idx += 1

In [17]:
# count classes
train_class_count = defaultdict(int)
test_class_count = defaultdict(int)

for line in shuffle_train:
    for label in line['labels']:
        train_class_count[label] += 1

for line in test:
    for label in line['labels']:
        test_class_count[label] += 1

unseen_class_count = train_class_count.keys() - test_class_count.keys()
print('Number of train data:', len(shuffle_train))
print('Number of test data:', len(test))

print('Number of classes in train data:', len(train_class_count))
print('Number of classes in test data:', len(test_class_count))
print('Number of unseen classes:', len(unseen_class_count))

print('Count of train label:\n', sorted(train_class_count.items(), key=lambda x: -x[1]))
print('Count of test label:\n', sorted(test_class_count.items(), key=lambda x: -x[1]))

train_text_length, test_text_length = 0, 0
for line in shuffle_train:
    train_text_length += len(line['text'])
for line in test:
    test_text_length += len(line['text'])

print('Avg length of train text:', train_text_length / len(shuffle_train))
print('Avg length of test text:', test_text_length / len(test))

Number of train data: 671
Number of test data: 2013
Number of classes in train data: 61
Number of classes in test data: 61
Number of unseen classes: 0
Count of train label:
 [('grain', 105), ('money-fx', 55), ('oilseed', 50), ('corn', 47), ('wheat', 43), ('veg-oil', 38), ('crude', 35), ('livestock', 35), ('trade', 32), ('meal-feed', 26), ('gnp', 24), ('soybean', 24), ('interest', 21), ('coffee', 20), ('gold', 19), ('ship', 19), ('zinc', 19), ('acq', 18), ('dlr', 18), ('earn', 17), ('rice', 17), ('gas', 17), ('carcass', 16), ('cotton', 16), ('strategic-metal', 15), ('sugar', 15), ('nat-gas', 15), ('rubber', 15), ('barley', 14), ('bop', 14), ('cpi', 14), ('money-supply', 14), ('cocoa', 14), ('soy-meal', 14), ('copper', 13), ('iron-steel', 13), ('jobs', 13), ('ipi', 13), ('silver', 13), ('palm-oil', 13), ('retail', 13), ('housing', 12), ('reserves', 12), ('wpi', 12), ('sorghum', 12), ('soy-oil', 12), ('orange', 12), ('rapeseed', 12), ('lead', 11), ('lei', 11), ('tin', 11), ('pet-chem', 11

In [None]:
# add label description
with open('label-description.jsonl') as f:
    label_desc = f.readlines()
    label_desc = list(map(json.loads, label_desc))
    label_desc = {line['label']: line['description'] for line in label_desc}

for line in shuffle_train:
    label_text = line['labels'][0]
    for i, label in enumerate(line['labels']):
        line['labels'][i] = label2id[label]
    line['label_description'] = label_desc[label_text]

for line in test:
    label_text = line['labels'][0]
    for i, label in enumerate(line['labels']):
        line['labels'][i] = label2id[label]
    line['label_description'] = label_desc[label_text]

with open(f'{state}/train.json', 'w') as f:
    json.dump(shuffle_train, f, ensure_ascii=False, indent=2)

with open(f'{state}/test.json', 'w') as f:
    json.dump(test, f, ensure_ascii=False, indent=2)

In [None]:
# add label description
with open('label-description-with-example.jsonl') as f:
    label_desc = f.readlines()
    label_desc = list(map(json.loads, label_desc))
    label_desc = {line['label']: line['description'] for line in label_desc}

for line in shuffle_train:
    label_text = id2label[line['labels'][0]]
    line['label_description'] = label_desc[label_text]

for line in test:
    label_text = id2label[line['labels'][0]]
    line['label_description'] = label_desc[label_text]

with open(f'{state}/train-with-example.json', 'w') as f:
    json.dump(shuffle_train, f, ensure_ascii=False, indent=2)

with open(f'{state}/test-with-example.json', 'w') as f:
    json.dump(test, f, ensure_ascii=False, indent=2)