In [1]:
import json
import os
import re
from collections import defaultdict

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
# Name of the sampling configuration; it selects the filtering branch further
# down and is also used as the output directory for the generated files.
state = '80%*1'


def _load_split(path, concepts):
    """Read one JSONL split and return a list of ``{'text', 'labels'}`` dicts.

    Args:
        path: Path of a JSONL file whose records carry the keys ``header``,
            ``recitals``, ``main_body`` (a list of strings) and
            ``eurovoc_concepts`` (a list of concept ids).
        concepts: Mapping from concept id to concept title.

    Returns:
        One dict per record: ``text`` is the header, recitals and body parts
        joined with newlines; ``labels`` holds the concept titles.

    Raises:
        KeyError: If a record references a concept id missing from ``concepts``.
    """
    examples = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, encoding='utf-8') as f:
        for raw in f:
            record = json.loads(raw)
            examples.append({
                'text': "\n".join([record["header"], record["recitals"]] + record["main_body"]),
                'labels': [concepts[label_idx] for label_idx in record['eurovoc_concepts']]
            })
    return examples


# Map every concept id to its human-readable title.
with open('origin/eurovoc_concepts.jsonl', encoding='utf-8') as jsonl_file:
    eurovoc_concepts = {concept['id']: concept['title'] for concept in map(json.loads, jsonl_file)}

train = _load_split('origin/train.jsonl', eurovoc_concepts)

# The test and dev files are merged into a single evaluation list (test first).
test = _load_split('origin/test.jsonl', eurovoc_concepts) + _load_split('origin/dev.jsonl', eurovoc_concepts)

In [3]:
# Flatten every document onto one line: newlines become spaces and
# backslashes are removed. Applied to both splits.
for split in (train, test):
    for line in split:
        line['text'] = line['text'].replace('\n', ' ').replace('\\', '')

In [4]:
# Count how many documents carry each label, separately for train and test.
train_class_count, test_class_count = defaultdict(int), defaultdict(int)

for split, counter in ((train, train_class_count), (test, test_class_count)):
    for example in split:
        for label in example['labels']:
            counter[label] += 1

# Labels that occur in the training split but never in the test split.
unseen_class_count = set(train_class_count) - set(test_class_count)

In [5]:
# Report split sizes and label statistics before any filtering.
for caption, value in (
    ('Number of train data:', len(train)),
    ('Number of test data:', len(test)),
    ('Number of classes in train data:', len(train_class_count)),
    ('Number of classes in test data:', len(test_class_count)),
    ('Number of unseen classes:', len(unseen_class_count)),
):
    print(caption, value)

# Label frequencies, most frequent first (ties keep insertion order).
for caption, counter in (
    ('Count of train label:\n', train_class_count),
    ('Count of test label:\n', test_class_count),
):
    print(caption, sorted(counter.items(), key=lambda item: item[1], reverse=True))

Number of train data: 45000
Number of test data: 12000
Number of classes in train data: 4108
Number of classes in test data: 3087
Number of unseen classes: 1184
Count of train label:
 [('import', 3645), ('export refund', 3554), ('pip fruit', 3005), ('fruit vegetable', 3004), ('citrus fruit', 2988), ('import price', 2959), ('award of contract', 2796), ('tariff quota', 2271), ('third country', 2158), ('originating product', 2152), ('import licence', 2152), ('import (EU)', 1897), ('CCT duties', 1714), ('common organisation of markets', 1609), ('health control', 1599), ('agri-monetary policy', 1413), ('Spain', 1391), ('beef', 1296), ('rice', 1295), ('stone fruit', 1259), ('sea fish', 1253), ('veterinary inspection', 1233), ('cereals', 1181), ('white sugar', 1137), ('marketing', 1106), ('grape', 1096), ('EU aid', 1043), ('catch quota', 1034), ('EU Member State', 1027), ('butter', 1009), ('quantitative restriction', 990), ('Germany', 981), ('Portugal', 952), ("ship's flag", 943), ('export li

In [6]:
# Total number of characters per split, then the per-document average.
train_text_length = sum(len(example['text']) for example in train)
test_text_length = sum(len(example['text']) for example in test)

print('Avg length of train text:', train_text_length / len(train))
print('Avg length of test text:', test_text_length / len(test))

Avg length of train text: 3432.534888888889
Avg length of test text: 3374.8723333333332


In [7]:
# Assign consecutive integer ids to labels in order of first appearance:
# all training labels first, then labels that only occur in the test split.
label2id = {}

for split in (train, test):
    for example in split:
        for label in example['labels']:
            # The default is evaluated before insertion, so a new label
            # receives the current size of the mapping as its id.
            label2id.setdefault(label, len(label2id))

cnt = len(label2id)

# Inverse mapping: id -> label name.
id2label = {label_id: label for label, label_id in label2id.items()}

In [8]:
# Move each training document's rarest label (lowest training frequency) to
# position 0, so that labels[0] can serve as the document's single label.
for line in train:
    labels = line['labels']
    if not labels:
        # Nothing to reorder for a document without labels.
        continue
    # min() returns the first minimum, so ties resolve to the earliest label
    # in the list. No hard-coded upper bound on label frequency is needed.
    rarest = min(labels, key=lambda label: train_class_count[label])
    if labels[0] != rarest:
        labels.remove(rarest)
        labels.insert(0, rarest)

In [9]:
# Group training documents by their primary label (labels[0]); there is one
# bucket per training label, indexed by the label's id.
label_collection = [[] for _ in train_class_count]
for example in train:
    label_collection[label2id[example['labels'][0]]].append(example)

if state == '80%*1':
    # The threshold is the 80th percentile of bucket sizes. Buckets smaller
    # than it are dropped; the rest are truncated to exactly that size.
    bucket_sizes = [len(bucket) for bucket in label_collection]
    threshold = int(np.percentile(bucket_sizes, 80))
    label_collection = [bucket[:threshold * 1] for bucket in label_collection if len(bucket) >= threshold]

In [10]:
# Rebuild the label <-> id mappings so that ids are the positions of the
# surviving buckets; each bucket is named after its documents' primary label.
label2id = {bucket[0]['labels'][0]: position for position, bucket in enumerate(label_collection)}

id2label = {position: label for label, position in label2id.items()}

In [11]:
# Keep only the test documents whose labels all have an id in the refreshed
# label2id; a document with any unknown label is dropped entirely.
# One filtering pass replaces the repeated list.pop(idx) calls, each of which
# shifted the remainder of the list. Slice assignment keeps the same list
# object, matching the original in-place deletion.
test[:] = [line for line in test if all(label in label2id for label in line['labels'])]

In [12]:
# Persist both directions of the label mapping under the `state` directory.
# Create the directory first so a fresh run does not fail on a missing folder.
os.makedirs(state, exist_ok=True)

# ensure_ascii=False emits non-ASCII characters verbatim, so the files are
# opened explicitly as UTF-8 rather than with the platform default encoding.
# The integer keys of id2label are written as strings by the JSON encoder.
with open(f'{state}/id2label.json', 'w', encoding='utf-8') as f:
    json.dump(id2label, f, ensure_ascii=False, indent=2)

with open(f'{state}/label2id.json', 'w', encoding='utf-8') as f:
    json.dump(label2id, f, ensure_ascii=False, indent=2)

In [13]:
# Interleave the per-label buckets round-robin: round r takes the r-th
# document of every bucket that still has one, in bucket order. Despite the
# "shuffle" name, the result is deterministic; no randomness is involved.
# Iterating by rounds avoids recomputing the total number of documents on
# every loop iteration and yields the same order as cycling with per-bucket
# pointers.
longest_bucket = max((len(bucket) for bucket in label_collection), default=0)

shuffle_train = []
for round_idx in range(longest_bucket):
    for bucket in label_collection:
        if round_idx < len(bucket):
            shuffle_train.append(bucket[round_idx])

In [14]:
# Cap the test split at three times the number of training documents.
max_test_size = 3 * len(shuffle_train)
test = test[:max_test_size]

In [15]:
# Report the sizes of the sampled training split and the filtered test split.
for caption, split in (('Number of train data:', shuffle_train), ('Number of test data:', test)):
    print(caption, len(split))

Number of train data: 11089
Number of test data: 1201


In [16]:
# Remove from every training document the labels that no longer have an id
# in label2id. The list is filtered in place, so its order is preserved.
for example in shuffle_train:
    example['labels'][:] = [label for label in example['labels'] if label in label2id]

In [17]:
# Recompute label frequencies and corpus statistics on the filtered splits.
train_class_count, test_class_count = defaultdict(int), defaultdict(int)

for split, counter in ((shuffle_train, train_class_count), (test, test_class_count)):
    for example in split:
        for label in example['labels']:
            counter[label] += 1

# Labels that occur in the training split but never in the test split.
unseen_class_count = set(train_class_count) - set(test_class_count)

for caption, value in (
    ('Number of train data:', len(shuffle_train)),
    ('Number of test data:', len(test)),
    ('Number of classes in train data:', len(train_class_count)),
    ('Number of classes in test data:', len(test_class_count)),
    ('Number of unseen classes:', len(unseen_class_count)),
):
    print(caption, value)

# Label frequencies, most frequent first (ties keep insertion order).
for caption, counter in (
    ('Count of train label:\n', train_class_count),
    ('Count of test label:\n', test_class_count),
):
    print(caption, sorted(counter.items(), key=lambda item: item[1], reverse=True))

# Total number of characters per split, then the per-document average.
train_text_length = sum(len(example['text']) for example in shuffle_train)
test_text_length = sum(len(example['text']) for example in test)

print('Avg length of train text:', train_text_length / len(shuffle_train))
print('Avg length of test text:', test_text_length / len(test))

Number of train data: 11089
Number of test data: 1201
Number of classes in train data: 853
Number of classes in test data: 560
Number of unseen classes: 293
Count of train label:
 [('tariff quota', 652), ('beef', 363), ('quantitative restriction', 319), ('production aid', 313), ('cereals', 284), ('agricultural product', 283), ('common organisation of markets', 274), ('fishing area', 265), ('butter', 261), ('intervention agency', 252), ('cattle', 252), ('swine', 249), ('aid to agriculture', 239), ('milk product', 236), ('rice', 223), ('product designation', 202), ('export licence', 202), ('textile product', 194), ('milk', 191), ('citrus fruit', 190), ('import policy', 187), ('pigmeat', 182), ('Greece', 180), ('derogation from EU law', 178), ('product quality', 177), ('fishing regulations', 176), ('import price', 175), ('sugar', 174), ('olive oil', 172), ('export', 171), ('poultry', 167), ('China', 166), ('fresh meat', 165), ('fishery product', 162), ('veterinary legislation', 154), ('po

In [18]:
# Attach the description of labels[0] to every document, convert label names
# to integer ids, and write both splits to disk.
# Files are opened explicitly as UTF-8: the input is JSON text, and the output
# is dumped with ensure_ascii=False, so neither should depend on the platform
# default encoding.
with open('label-description.jsonl', encoding='utf-8') as f:
    label_desc = {entry['label']: entry['description'] for entry in map(json.loads, f)}

for split in (shuffle_train, test):
    for line in split:
        # Remember the primary label's name before names are replaced by ids.
        # NOTE(review): only the training split had its rarest label moved to
        # position 0 earlier; for test documents labels[0] is simply the first
        # listed label. Confirm this is intended.
        label_text = line['labels'][0]
        line['labels'][:] = [label2id[label] for label in line['labels']]
        line['label_description'] = label_desc[label_text]

with open(f'{state}/train.json', 'w', encoding='utf-8') as f:
    json.dump(shuffle_train, f, ensure_ascii=False, indent=2)

with open(f'{state}/test.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, ensure_ascii=False, indent=2)

In [19]:
# Replace each document's label description with the variant read from
# 'label-description-with-example.jsonl' and write the splits to separate
# files. Labels are integer ids at this point, so the primary label's name is
# recovered through id2label.
# Files are opened explicitly as UTF-8: the input is JSON text, and the output
# is dumped with ensure_ascii=False, so neither should depend on the platform
# default encoding.
with open('label-description-with-example.jsonl', encoding='utf-8') as f:
    label_desc = {entry['label']: entry['description'] for entry in map(json.loads, f)}

for split in (shuffle_train, test):
    for line in split:
        label_text = id2label[line['labels'][0]]
        line['label_description'] = label_desc[label_text]

with open(f'{state}/train-with-example.json', 'w', encoding='utf-8') as f:
    json.dump(shuffle_train, f, ensure_ascii=False, indent=2)

with open(f'{state}/test-with-example.json', 'w', encoding='utf-8') as f:
    json.dump(test, f, ensure_ascii=False, indent=2)