In [24]:
import json
import os
import random
import re
from collections import defaultdict

import numpy as np

In [25]:
# Sampling configuration; later cells also use it as the output directory name.
state = '25%*3'

# Parse the topic file into {label: description}.
label_desc = dict()
with open('origin/topics.rbb') as f:
    label_desc_raw = f.readlines()
    for lineno, raw_line in enumerate(label_desc_raw, start=1):
        stripped = raw_line.strip()
        if not stripped:
            # Blank lines carry no label.
            continue
        # Columns are separated by a run of four or more spaces. Splitting on
        # the whole run (instead of trying 6, 5, then 4 spaces) keeps the
        # description free of leftover leading spaces for wider gaps.
        fields = re.split(r' {4,}', stripped)
        if len(fields) < 2:
            # Without a separator there is no description column; indexing the
            # raw string here would silently store single characters.
            print(f'Skipping malformed topic line {lineno}: {stripped!r}')
            continue
        label_desc[fields[0]] = fields[1]

# Persist the mapping as JSON Lines: one {"label", "description"} object per line.
with open('label-description.jsonl', 'w') as f:
    for k, v in label_desc.items():
        json.dump({'label': k, 'description': v}, f)
        f.write('\n')

In [26]:
# Build the document-id -> list-of-labels mapping from the qrels file.
with open('origin/rcv1-v2.topics.qrels') as f:
    # Every row is split on single spaces and must yield exactly three fields;
    # only the first (label) and the second (document id) are used.
    topic_did = [row.strip().split(' ') for row in f.readlines()]

topic_did_dict = defaultdict(list)
for label, did, _ in topic_did:
    topic_did_dict[did].append(label)

In [27]:
def load_data(data, doc_labels=None):
    """Parse the lines of a token file into a list of documents.

    A line matching the ``.I`` header pattern starts a new document and
    supplies its id; a ``.W`` line is skipped; every other line is stripped
    and appended to the current document's text, separated by one space.
    A document is only emitted if it accumulated some text.

    Args:
        data: Iterable of raw lines, each terminated by a newline.
        doc_labels: Optional mapping from document id (str) to a list of
            labels. Defaults to the module-level ``topic_did_dict``.

    Returns:
        A list of dicts with a ``'text'`` string and, for documents that
        were introduced by a header line, a ``'labels'`` list. Ids missing
        from the mapping get an empty label list.
    """
    if doc_labels is None:
        doc_labels = topic_did_dict

    # Raw strings avoid invalid-escape warnings; compiling once avoids
    # evaluating the header pattern twice per line.
    header_re = re.compile(r'\.I.(.*?)\n')
    body_marker_re = re.compile(r'\.W\n')

    ret_data = []
    piece = {'text': ''}
    for line in data:
        header = header_re.search(line)
        if header:
            if piece['text']:
                piece['text'] = piece['text'].strip()
                ret_data.append(piece)
                piece = {'text': ''}
            # Copy the label list so that later in-place edits of a document's
            # labels do not modify the shared id -> labels mapping.
            piece['labels'] = list(doc_labels.get(header.group(1), []))
            continue
        if body_marker_re.search(line):
            continue
        piece['text'] += line.strip() + ' '

    # Flush the last document, which has no following header to trigger it.
    if piece['text']:
        piece['text'] = piece['text'].strip()
        ret_data.append(piece)

    return ret_data

In [28]:
# Read the training token file and parse it into documents.
with open('origin/lyrl2004_tokens_train.dat') as f:
    train = load_data(f.readlines())

In [29]:
# Gather the raw lines of the four test partitions, then parse them in one go.
test = []
for pt in ['pt0', 'pt1', 'pt2', 'pt3']:
    with open(f'origin/lyrl2004_tokens_test_{pt}.dat') as f:
        test.extend(f.readlines())
test = load_data(test)

In [30]:
# Replace newline characters with a space and drop backslashes in every
# document text of both splits.
for dataset in (train, test):
    for doc in dataset:
        doc['text'] = doc['text'].replace('\n', ' ').replace('\\', '')

In [31]:
# Count how many documents carry each label, separately per split.
train_class_count = defaultdict(int)
test_class_count = defaultdict(int)

for counter, dataset in ((train_class_count, train), (test_class_count, test)):
    for doc in dataset:
        for label in doc['labels']:
            counter[label] += 1

# Set difference of the keys: labels that occur in train but never in test.
# NOTE(review): despite the name this is a set of labels, and it does not
# report test labels missing from train - confirm that this is the intent.
unseen_class_count = train_class_count.keys() - test_class_count.keys()

In [32]:
# Report split sizes, label-set sizes, and per-label counts (most frequent first).
for caption, value in (
    ('Number of train data:', len(train)),
    ('Number of test data:', len(test)),
    ('Number of classes in train data:', len(train_class_count)),
    ('Number of classes in test data:', len(test_class_count)),
    ('Number of unseen classes:', len(unseen_class_count)),
    ('Count of train label:\n', sorted(train_class_count.items(), key=lambda kv: -kv[1])),
    ('Count of test label:\n', sorted(test_class_count.items(), key=lambda kv: -kv[1])),
):
    print(caption, value)

Number of train data: 23149
Number of test data: 781265
Number of classes in train data: 101
Number of classes in test data: 103
Number of unseen classes: 0
Count of train label:
 [('CCAT', 10786), ('GCAT', 6970), ('MCAT', 5882), ('C15', 4179), ('ECAT', 3449), ('M14', 2541), ('C151', 2366), ('C152', 1930), ('GPOL', 1647), ('M13', 1596), ('M141', 1508), ('C18', 1462), ('M11', 1294), ('E21', 1255), ('C181', 1205), ('C17', 1172), ('GCRIM', 1133), ('GVIO', 1115), ('C31', 1058), ('GDIP', 1004), ('C13', 947), ('M131', 943), ('C24', 922), ('GSPO', 913), ('E212', 853), ('C21', 793), ('M12', 732), ('M132', 699), ('E12', 679), ('C11', 674), ('E51', 641), ('M143', 606), ('GJOB', 471), ('E41', 449), ('C33', 443), ('C171', 437), ('E211', 407), ('E512', 400), ('C1511', 399), ('C12', 381), ('G15', 363), ('GVOTE', 346), ('C42', 343), ('C41', 312), ('M142', 311), ('GDIS', 293), ('C411', 286), ('C172', 285), ('E11', 279), ('C174', 246), ('GDEF', 233), ('C183', 202), ('GHEA', 197), ('C312', 196), ('C22',

In [33]:
# Total number of characters per split, then the per-document average.
train_text_length = sum(len(doc['text']) for doc in train)
test_text_length = sum(len(doc['text']) for doc in test)

print('Avg length of train text:', train_text_length / len(train))
print('Avg length of test text:', test_text_length / len(test))

Avg length of train text: 746.1848459976673
Avg length of test text: 764.0679462154327


In [34]:
# label2id: ids are handed out in order of first appearance, scanning the
# train split before the test split.
label2id = dict()
for dataset in (train, test):
    for doc in dataset:
        for label in doc['labels']:
            # len() is evaluated before the insertion, so it is the next free id.
            label2id.setdefault(label, len(label2id))
cnt = len(label2id)

# Inverse lookup: id -> label.
id2label = {label_id: name for name, label_id in label2id.items()}

In [35]:
# to one label, labels[0] is the rarest one
# Move each training document's rarest label (lowest train_class_count) to the
# front of its label list.
for line in train:
    labels = line['labels']
    # min() returns the first minimal item, so ties resolve to the earliest
    # label in the list, and no hard-coded upper bound on the counts is needed.
    # It raises ValueError for a document without any label.
    minn_label = min(labels, key=lambda label: train_class_count[label])
    if labels[0] != minn_label:
        labels.remove(minn_label)
        labels.insert(0, minn_label)

In [36]:
# delete long tail data
# Bucket the training documents by the id of their first label.
label_collection = [[] for _ in range(len(train_class_count))]
for doc in train:
    label_collection[label2id[doc['labels'][0]]].append(doc)

# state -> (percentile of the bucket sizes used as threshold, cap multiplier).
# Any other state leaves the buckets untouched.
sampling_plans = {'50%*1': (50, 1), '25%*3': (25, 3)}
if state in sampling_plans:
    percentile, multiplier = sampling_plans[state]
    bucket_sizes = [len(bucket) for bucket in label_collection]
    threshold = int(np.percentile(bucket_sizes, percentile))
    # Drop buckets smaller than the threshold and cap the remaining ones at
    # threshold * multiplier documents.
    label_collection = [
        bucket[:threshold * multiplier]
        for bucket in label_collection
        if len(bucket) >= threshold
    ]

In [37]:
# refresh label2id
# Renumber the surviving buckets: the key is the first label of a bucket's
# first document, the id is the bucket's position in label_collection.
label2id = {
    bucket[0]['labels'][0]: position
    for position, bucket in enumerate(label_collection)
}

id2label = {position: name for name, position in label2id.items()}

In [38]:
# delete unseen test
# Keep only the test documents whose labels are all present in label2id
# (documents without labels are kept). The filtered list is built in a single
# pass, which avoids calling list.pop(idx) once per removed document.
# Slice assignment updates the existing list object in place.
test[:] = [
    doc for doc in test
    if all(label in label2id for label in doc['labels'])
]

In [39]:
# Persist both label mappings in the directory named after `state`.
# Create the directory first so the writes below do not fail with
# FileNotFoundError when it does not exist yet.
os.makedirs(state, exist_ok=True)

# json.dump turns the integer keys of id2label into strings.
with open(f'{state}/id2label.json', 'w') as f:
    json.dump(id2label, f, ensure_ascii=False, indent=2)

with open(f'{state}/label2id.json', 'w') as f:
    json.dump(label2id, f, ensure_ascii=False, indent=2)

In [40]:
# shuffle train dataset
# Deterministic round-robin interleave of the buckets: first the first
# document of every bucket, then the second of every bucket, and so on,
# skipping buckets that have run out. This yields the same order as cycling
# over the buckets with one read pointer each, without recomputing the total
# number of documents on every step.
shuffle_train = []
longest_bucket = max((len(bucket) for bucket in label_collection), default=0)
for position in range(longest_bucket):
    for bucket in label_collection:
        if position < len(bucket):
            shuffle_train.append(bucket[position])

In [41]:
# Keep only a leading slice of test, three times as long as the training set.
test_limit = len(shuffle_train) * 3
test = test[:test_limit]

In [42]:
# Sizes of the two splits after sampling.
for caption, dataset in (
    ('Number of train data:', shuffle_train),
    ('Number of test data:', test),
):
    print(caption, len(dataset))

Number of train data: 6465
Number of test data: 19395


In [43]:
# delete unseen train labels
# Keep, per training document, only the labels that are keys of label2id.
for doc in shuffle_train:
    # Slice assignment filters the existing list object in place.
    doc['labels'][:] = [label for label in doc['labels'] if label in label2id]

In [44]:
# Recount the labels on the sampled train set and the reduced test set.
train_class_count = defaultdict(int)
test_class_count = defaultdict(int)

for counter, dataset in ((train_class_count, shuffle_train), (test_class_count, test)):
    for doc in dataset:
        for label in doc['labels']:
            counter[label] += 1

# Labels that occur in the sampled train set but not in the test subset.
unseen_class_count = train_class_count.keys() - test_class_count.keys()

# Split sizes, label-set sizes, and per-label counts (most frequent first).
for caption, value in (
    ('Number of train data:', len(shuffle_train)),
    ('Number of test data:', len(test)),
    ('Number of classes in train data:', len(train_class_count)),
    ('Number of classes in test data:', len(test_class_count)),
    ('Number of unseen classes:', len(unseen_class_count)),
    ('Count of train label:\n', sorted(train_class_count.items(), key=lambda kv: -kv[1])),
    ('Count of test label:\n', sorted(test_class_count.items(), key=lambda kv: -kv[1])),
):
    print(caption, value)

# Average text length in characters per split.
train_text_length = sum(len(doc['text']) for doc in shuffle_train)
test_text_length = sum(len(doc['text']) for doc in test)

print('Avg length of train text:', train_text_length / len(shuffle_train))
print('Avg length of test text:', test_text_length / len(test))

Number of train data: 6465
Number of test data: 19395
Number of classes in train data: 76
Number of classes in test data: 25
Number of unseen classes: 51
Count of train label:
 [('GCAT', 2522), ('M14', 628), ('GPOL', 545), ('C31', 502), ('C13', 495), ('E51', 402), ('E12', 388), ('GCRIM', 375), ('G15', 338), ('C24', 305), ('M141', 299), ('GDIP', 291), ('C181', 282), ('C151', 255), ('C21', 237), ('GVIO', 228), ('E212', 220), ('E41', 219), ('C152', 214), ('M132', 213), ('C11', 197), ('GDIS', 194), ('M131', 183), ('M11', 178), ('E512', 177), ('E211', 176), ('GHEA', 163), ('C171', 162), ('C12', 153), ('E11', 139), ('C33', 139), ('GPRO', 138), ('C312', 134), ('M12', 129), ('G154', 129), ('M143', 124), ('GENV', 121), ('C34', 120), ('GVOTE', 119), ('C22', 118), ('C42', 118), ('E131', 116), ('GDEF', 111), ('C311', 110), ('M142', 108), ('C183', 108), ('C182', 107), ('C172', 105), ('GENT', 103), ('C411', 103), ('E71', 103), ('C174', 101), ('C1511', 101), ('GSPO', 101), ('GWEA', 101), ('C14', 99),

In [45]:
# add label description
# Load {label: description} from the JSON Lines file (one object per line).
with open('label-description.jsonl') as f:
    label_desc = {
        record['label']: record['description']
        for record in map(json.loads, f.readlines())
    }

# For both splits: remember the text of the first label, replace every label
# by its id (in place), then attach the description of that first label.
# This expects the labels to still be label strings when the cell runs.
for dataset in (shuffle_train, test):
    for doc in dataset:
        label_text = doc['labels'][0]
        doc['labels'][:] = [label2id[label] for label in doc['labels']]
        doc['label_description'] = label_desc[label_text]

# Write each split as one indented JSON array.
for path, dataset in (
    (f'{state}/train.json', shuffle_train),
    (f'{state}/test.json', test),
):
    with open(path, 'w') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)

In [46]:
# add label description (variant read from the "with-example" file)
with open('label-description-with-example.jsonl') as f:
    label_desc = {
        record['label']: record['description']
        for record in map(json.loads, f.readlines())
    }

# Labels are ids at this point, so the first id is mapped back to its label
# text through id2label before the description lookup. The previous
# 'label_description' value of every document is overwritten.
for dataset in (shuffle_train, test):
    for doc in dataset:
        label_text = id2label[doc['labels'][0]]
        doc['label_description'] = label_desc[label_text]

# Write each split as one indented JSON array.
for path, dataset in (
    (f'{state}/train-with-example.json', shuffle_train),
    (f'{state}/test-with-example.json', test),
):
    with open(path, 'w') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)