In [41]:
import json

# Each line of the corpus is a JSON object with an 'id' and a hierarchical 'label', e.g.
# "label": ["People", "Artists, musicians, and composers", "Musicians and composers"]
# Two flattened label granularities are derived from it, joined with ' ->- ':
#   labels1 -> first two levels of the hierarchy
#   labels2 -> first three levels of the hierarchy
label_sep = ' ->- '

ids, labels1, labels2 = [], [], []
with open('wikivitals-lvl5-04-2022.json', 'r', encoding='utf8') as corpus_file:
    for raw_line in corpus_file:
        record = json.loads(raw_line)
        ids.append(record['id'])
        hierarchy = record['label']
        labels1.append(label_sep.join((hierarchy[0], hierarchy[1])))
        labels2.append(label_sep.join((hierarchy[0], hierarchy[1], hierarchy[2])))

total_number_of_articles = len(ids)
print(total_number_of_articles, len(labels1), len(labels2))

# Position of each article id in the parallel lists above
# (the misspelled name is kept because the split cells look ids up through it).
ids2lsitindex = {doc_id: position for position, doc_id in enumerate(ids)}

for preview in (ids, labels1, labels2):
    print(preview[:10])
print(f"Number of labels 1 and labels 2: {len(set(labels1)), len(set(labels2))}")

48512 48512 48512
['1000', '100005', '100011', '100012', '10002', '100031', '100034', '100041', '1000530', '10005756']
['Arts ->- Arts', 'Biological and health sciences ->- Animals', 'Geography ->- Countries', 'People ->- Philosophers, historians, political and social scientists', 'People ->- Philosophers, historians, political and social scientists', 'People ->- Writers and journalists', 'Technology ->- Technology', 'People ->- Politicians and leaders', 'Geography ->- Countries', 'Mathematics ->- Mathematics']
['Arts ->- Arts ->- Fictional characters', 'Biological and health sciences ->- Animals ->- Birds', 'Geography ->- Countries ->- Regions and country subdivisions', 'People ->- Philosophers, historians, political and social scientists ->- Social scientists', 'People ->- Philosophers, historians, political and social scientists ->- Social scientists', 'People ->- Writers and journalists ->- Writers', 'Technology ->- Technology ->- General', 'People ->- Politicians and leaders ->- A

In [27]:
# Find level-2 labels (labels2) with strictly fewer than MIN_OBS_PER_LABEL observations.
# Such classes cannot be stratified across the three sets (train, validation, test),
# so their articles are set aside here and reintegrated later.
from collections import Counter

MIN_OBS_PER_LABEL = 3

C = Counter(labels2)
# Counter keeps first-seen order, so the printed list follows corpus order.
labels_w_special_treatment = [label for label, num_obs in C.items() if num_obs < MIN_OBS_PER_LABEL]
print(labels_w_special_treatment)

# Partition the corpus: articles with a well-populated label go to the stratified
# split (ids_, labels1_, labels2_); the others are kept in ids_with_special_treatment.
assert len(ids) == len(labels1) == len(labels2), "id/label lists are not aligned"
special_labels = set(labels_w_special_treatment)  # O(1) membership tests in the loop below
ids_, labels1_, labels2_ = [], [], []
ids_with_special_treatment = []
for doc_id, label1, label2 in zip(ids, labels1, labels2):
    if label2 in special_labels:
        ids_with_special_treatment.append(doc_id)
    else:
        ids_.append(doc_id)
        labels1_.append(label1)
        labels2_.append(label2)
print(len(ids_), len(labels1_), len(labels2_))
print(ids_with_special_treatment)


['Physical sciences ->- Earth science ->- Earth science basics', 'People ->- Miscellaneous ->- Micronations']
48509 48509 48509
['20653168', '2890783', '31749258']


In [28]:
from sklearn.model_selection import train_test_split
import numpy as np


# Stratified 81% / 9% / 10% split on the level-2 labels, done in two passes:
# first carve out the test set, then split what remains into train and validation.
# The validation fraction is rescaled so it is relative to the remaining (train + validation) pool.
train_size, validation_size, test_size = 0.81, 0.09, 0.10
relative_validation_size = validation_size / (validation_size + train_size)

X = np.array(ids_)
y1, y2 = labels1_, labels2_

# Options common to both passes (fixed seed for a reproducible split).
shared_split_options = {'random_state': 42, 'shuffle': True}

X_train, X_test, y_train, y_test = train_test_split(
    X, y2, test_size=test_size, stratify=y2, **shared_split_options)
print(X_train.shape, X_test.shape, len(y_train), len(y_test))

X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=relative_validation_size, stratify=y_train, **shared_split_options)
print(X_train.shape, X_val.shape, len(y_train), len(y_val))

(43658,) (4851,) 43658 4851
(39292,) (4366,) 39292 4366


In [50]:
# Get the ids of the documents per split, then check and report the split.
from collections import Counter

ids_train = list(X_train)
ids_val = list(X_val)
ids_test = list(X_test)
# Reintegrate the articles whose level-2 label was set aside (fewer than 3 observations).
# All of them go to the train set: there are far fewer of them than articles already
# in the train set, so the overall split proportions are not affected.
ids_train = ids_train + ids_with_special_treatment

# Sanity checks: the splits are pairwise disjoint and together cover the whole corpus.
# numpy string scalars from X_* hash and compare like plain str (the ids2lsitindex
# lookups below rely on the same property), so mixing them in sets is safe.
train_id_set, val_id_set, test_id_set = set(ids_train), set(ids_val), set(ids_test)
assert train_id_set.isdisjoint(val_id_set), "train and validation splits overlap"
assert train_id_set.isdisjoint(test_id_set), "train and test splits overlap"
assert val_id_set.isdisjoint(test_id_set), "validation and test splits overlap"
assert len(ids_train) + len(ids_val) + len(ids_test) == total_number_of_articles, \
    "the splits do not cover every article exactly once"

# Display the split proportions.
stats = [len(ids_train), len(ids_val), len(ids_test)]
stats = [f"{s/total_number_of_articles:.2f}" for s in stats]
print(f"Split sizes: {stats}")

# Check that all labels have at least one observation in the train set.
indices_train = [ids2lsitindex[doc_id] for doc_id in ids_train]
labels1_train = [labels1[i] for i in indices_train]
labels2_train = [labels2[i] for i in indices_train]
print(f"Number of labels in the train set (labels 1 and labels 2): {len(set(labels1_train)), len(set(labels2_train))}")
assert set(labels1_train) == set(labels1), "some level-1 labels have no train observation"
assert set(labels2_train) == set(labels2), "some level-2 labels have no train observation"

print('\n')
indices_val = [ids2lsitindex[doc_id] for doc_id in ids_val]
indices_test = [ids2lsitindex[doc_id] for doc_id in ids_test]
labels1_val = [labels1[i] for i in indices_val]
labels2_val = [labels2[i] for i in indices_val]
labels1_test = [labels1[i] for i in indices_test]
labels2_test = [labels2[i] for i in indices_test]

C1_train, C1_val, C1_test = Counter(labels1_train), Counter(labels1_val), Counter(labels1_test)
C2_train, C2_val, C2_test = Counter(labels2_train), Counter(labels2_val), Counter(labels2_test)
all_labels1, all_labels2 = sorted(set(labels1)), sorted(set(labels2))


def _label_proportions(all_labels, counters, split_sizes):
    """Map each label to its relative frequency in every split.

    `counters` and `split_sizes` are parallel sequences (one entry per split, in
    train/validation/test order). A label absent from a split gets 0 (Counter's
    default); an empty split yields 0 instead of dividing by zero.
    """
    return {
        label: [counter[label] / size if size else 0 for counter, size in zip(counters, split_sizes)]
        for label in all_labels
    }


split_sizes = (len(ids_train), len(ids_val), len(ids_test))
C1 = _label_proportions(all_labels1, (C1_train, C1_val, C1_test), split_sizes)
C2 = _label_proportions(all_labels2, (C2_train, C2_val, C2_test), split_sizes)

# Per-label proportions per split (they should be close across splits thanks to stratification).
for sorted_labels, proportions in ((all_labels1, C1), (all_labels2, C2)):
    print('\n')
    for label in sorted_labels:
        print(f"{label} : {proportions[label]}")

Split sizes: ['0.81', '0.09', '0.10']
Number of labels in the train set (labels 1 and labels 2): (32, 251)




Arts ->- Arts : [0.06827840692200025, 0.06802565277141548, 0.06802721088435375]
Biological and health sciences ->- Animals : [0.04944649446494465, 0.04924415941365094, 0.049062049062049064]
Biological and health sciences ->- Biology : [0.01827204478941341, 0.01809436555199267, 0.01834673263244692]
Biological and health sciences ->- Health : [0.016312507952665735, 0.01626202473660101, 0.016285301999587713]
Biological and health sciences ->- Plants : [0.012546125461254613, 0.012368300503893724, 0.012574726860441147]
Everyday life ->- Everyday life : [0.024557831785214403, 0.024736601007787448, 0.02432488146773861]
Everyday life ->- Sports, games and recreation : [0.02534673622598295, 0.025423728813559324, 0.025561739847454135]
Geography ->- Cities : [0.04183738389108029, 0.04191479615208429, 0.04184704184704185]
Geography ->- Countries : [0.028553251049751878, 0.0286303252404947

In [53]:
# Save the sets as JSON-lines files, one per split.
file_prefix = 'wikivitals-lvl5-04-2022-dense' # dense because the split is 81%/9%/10% as in my article 'Fair evaluation of GMNN'

# Sets give O(1) lookups. Scanning the id lists for every corpus line cost
# O(n_articles * n_split_ids). str() normalises numpy string scalars to plain str.
test_ids_lookup = {str(doc_id) for doc_id in ids_test}
val_ids_lookup = {str(doc_id) for doc_id in ids_val}

# A single `with` closes every file even if a line fails to parse.
with open('wikivitals-lvl5-04-2022.json', 'r', encoding='utf8') as f, \
        open(f'{file_prefix}_train.json', 'w', encoding='utf8') as train_file, \
        open(f'{file_prefix}_val.json', 'w', encoding='utf8') as val_file, \
        open(f'{file_prefix}_test.json', 'w', encoding='utf8') as test_file:
    for line in f:
        observation = json.loads(line)
        if observation['id'] in test_ids_lookup:
            test_file.write(json.dumps(observation) + '\n')
        elif observation['id'] in val_ids_lookup:
            val_file.write(json.dumps(observation) + '\n')
        else:
            # Everything else is train, including the reintegrated rare-label articles.
            train_file.write(json.dumps(observation) + '\n')


In [54]:
# Select N_SAMPLES_PER_LABEL1 train observations for each level-1 label
# (32 labels * 20 = 640 observations), drawn from the dense train split built above.
from random import Random
from random import sample

N_SAMPLES_PER_LABEL1 = 20
SPARSE_SEED = 42  # seeded so the sparse split is reproducible across runs
sparse_rng = Random(SPARSE_SEED)

# Group the train ids by level-1 label (labels1_train is aligned with ids_train).
label2ids_train = {l: [] for l in all_labels1}
for doc_id, label1 in zip(ids_train, labels1_train):
    label2ids_train[label1].append(doc_id)

ids_train_sparse = []
for l in all_labels1:
    # Random.sample raises ValueError if a label has fewer than N_SAMPLES_PER_LABEL1 train ids.
    ids_train_sparse.extend(sparse_rng.sample(label2ids_train[l], N_SAMPLES_PER_LABEL1))

# Save the new split
file_prefix = 'wikivitals-lvl5-04-2022-sparse' # sparse because it contains 640 observations only (among the 48512 observations in the corpus)

# Set lookup instead of a list scan per line; str() normalises numpy string scalars.
sparse_ids_lookup = {str(doc_id) for doc_id in ids_train_sparse}
with open('wikivitals-lvl5-04-2022-dense_train.json', 'r', encoding='utf8') as f, \
        open(f'{file_prefix}_train.json', 'w', encoding='utf8') as sparse_train_file:
    for line in f:
        observation = json.loads(line)
        if observation['id'] in sparse_ids_lookup:
            sparse_train_file.write(json.dumps(observation) + '\n')
    