In [1]:
import pickle
import os
import json
import numpy as np
import pandas as pd
import re
from sklearn.model_selection import train_test_split

In [2]:
# NOTE(review): `method` is not referenced elsewhere in this notebook — confirm it is still needed.
method = 'pool'

dataset_name = 'wos46985'
base_dir = '../../data/WOS/'
data_file = base_dir+'Meta-data/Data.csv'

# Output directory for the generated index files (shared with the embedding data).
text_embedding_dir = '../data/'+dataset_name

seed_indices_file = text_embedding_dir+'/seed_indices.json'
remaining_indices_file = text_embedding_dir+'/remaining_indices.json'

# The train_test_split calls below do not pass random_state, so they draw from
# NumPy's global RNG. Seeding it here makes the seed/remaining split reproducible
# under Restart & Run All. Change the value to draw a different split.
random_seed = 42
np.random.seed(random_seed)

In [3]:
# Load the WOS metadata table; the seed/remaining indices saved later are labels of this frame's index.
df = pd.read_csv(data_file)
print(f'Num Data: {len(df)}')

Num Data: 46985


In [4]:
df.iloc[:2]  # peek at the first two rows

Unnamed: 0,Y1,Y2,Y,Domain,area,keywords,Abstract
0,0,12,12,CS,Symbolic computation,(2+1)-dimensional non-linear optical waves; e...,"""(2 + 1)-dimensional non-linear optical waves ..."
1,5,2,74,Medical,Alzheimer's Disease,Aging; Tau; Amyloid; PET; Alzheimer's disease...,"""(beta-amyloid (A beta) and tau pathology beco..."


In [5]:
# Count distinct labels at each level of the hierarchy.
# NOTE(review): the L2 count uses column 'Y' rather than 'Y2'; presumably 'Y' is a
# globally unique sub-class id while 'Y2' repeats across parents — confirm.
for level_label, label_column in (('Num L1 Classes:', 'Y1'), ('Num L2 Classes:', 'Y')):
    print(level_label, len(df[label_column].unique()))

Num L1 Classes: 7
Num L2 Classes: 134


In [6]:
def split_seed_df(Y1, Y2, sample_N=3, random_state=None):
    """Split the documents of one (Y1, Y2) class into seed and remaining indices.

    Parameters
    ----------
    Y1, Y2 : label values
        Parent (``Y1``) and child (``Y2``) labels. Rows of the global ``df``
        matching both labels are the candidate documents.
    sample_N : int, default 3
        Number of seed documents to draw. When the class has at most
        ``sample_N`` documents, half of them (rounded down) are drawn instead,
        so at least one document remains on each side of the split.
    random_state : int, numpy.random.RandomState or None, default None
        Forwarded to ``train_test_split``. Pass a fixed value for a
        reproducible per-call split. ``None`` keeps the previous behaviour
        (NumPy's global RNG).

    Returns
    -------
    tuple
        ``(seed_indices, remaining_indices)`` as the ``.values`` of the
        selected ``df`` index labels. For a class with 0 or 1 document nothing
        can be split, so both elements hold the same (possibly empty) indices.
    """
    text_df = df[(df.Y1 == Y1) & (df.Y2 == Y2)]
    grp_indices = text_df.index

    if len(grp_indices) > 1:
        if len(grp_indices) <= sample_N:
            # Halve small classes so both the seed and remaining sides are non-empty.
            sample_N = len(grp_indices) // 2
        remaining_indices, seed_indices = train_test_split(
            grp_indices, test_size=sample_N, random_state=random_state
        )
    else:
        # 0 or 1 document: reuse the same index as both seed and remaining.
        seed_indices = grp_indices
        remaining_indices = grp_indices

    return (seed_indices.values, remaining_indices.values)

def get_parent_seed(Y1, sample_N=8, random_state=None):
    """Randomly draw seed document indices for parent class ``Y1``.

    Parameters
    ----------
    Y1 : label value
        Parent label. All rows of the global ``df`` with this ``Y1`` are candidates.
    sample_N : int, default 8
        Number of seed documents to draw. Parents with at most ``sample_N``
        documents get half of them (rounded down) instead of a ValueError from
        ``train_test_split``. Parents with 0 or 1 document return all of them.
    random_state : int, numpy.random.RandomState or None, default None
        Forwarded to ``train_test_split``. ``None`` keeps the previous
        behaviour (NumPy's global RNG).

    Returns
    -------
    The sampled ``df`` index labels, in the index-like form returned by
    ``train_test_split`` (callers convert with ``.tolist()``).
    """
    parent_all_seeds = df[(df.Y1 == Y1)].index

    # train_test_split rejects an integer test_size that is not strictly below
    # the number of samples, so small parents use the same fallback as split_seed_df.
    if len(parent_all_seeds) <= 1:
        return parent_all_seeds
    if len(parent_all_seeds) <= sample_N:
        sample_N = len(parent_all_seeds) // 2

    _, seed_indices = train_test_split(
        parent_all_seeds, test_size=sample_N, random_state=random_state
    )
    return seed_indices

In [7]:
num_sample = 4

# One row per (Y1, Y2) class; the 'Y' column holds the class size.
class_groups = df.groupby(['Y1','Y2'])['Y'].size().reset_index()

# Skeletons keyed by parent label (Python ints so json.dump accepts them):
#   seeds[parent]     = {'indices': [...], 'sub-topic': {child: {'indices': [...]}}}
#   remaining[parent] = {'sub-topic': {child: {'indices': [...]}}}
# The parent-level 'indices' stay empty here and are filled in the next cell.
parent_labels = class_groups.Y1.unique().tolist()
seeds = {label: {'indices': [], 'sub-topic': {}} for label in parent_labels}
remaining = {label: {'sub-topic': {}} for label in parent_labels}

# Split every sub-topic into seed / remaining documents, in class_groups row order.
for parent, child in zip(class_groups.Y1.tolist(), class_groups.Y2.tolist()):
    seed, dat = split_seed_df(parent, child, num_sample)
    seeds[parent]['sub-topic'][child] = {'indices': seed.tolist()}
    remaining[parent]['sub-topic'][child] = {'indices': dat.tolist()}

In [8]:
num_sample = 8

# Draw parent-level (Y1) seed documents and store them next to the sub-topic seeds.
for parent_label in class_groups.Y1.unique().tolist():
    parent_seed = get_parent_seed(parent_label, num_sample)
    seeds[parent_label]['indices'] = parent_seed.tolist()

In [9]:
# Remove each parent's seed documents from all of its sub-topics' remaining
# lists, so a document is not both a parent seed and a remaining entry.
for parent, child in zip(class_groups.Y1.tolist(), class_groups.Y2.tolist()):
    parent_seed_set = set(seeds[parent]['indices'])
    child_pool = remaining[parent]['sub-topic'][child]
    child_pool['indices'] = [
        doc_idx for doc_idx in child_pool['indices'] if doc_idx not in parent_seed_set
    ]

In [10]:
# Optional: uncomment the line below to inspect both dicts (the output is large).
# seeds, remaining

In [11]:
# Persist the seed indices. Create the output directory first so a fresh
# checkout does not fail with FileNotFoundError (dirname may be '' for a bare filename).
seed_out_dir = os.path.dirname(seed_indices_file)
if seed_out_dir:
    os.makedirs(seed_out_dir, exist_ok=True)
with open(seed_indices_file, "w") as outfile:
    json.dump(seeds, outfile, indent=4)

In [12]:
# Persist the remaining (non-seed) indices. Ensure the output directory exists
# first (dirname may be '' for a bare filename, which makedirs would reject).
remaining_out_dir = os.path.dirname(remaining_indices_file)
if remaining_out_dir:
    os.makedirs(remaining_out_dir, exist_ok=True)
with open(remaining_indices_file, "w") as outfile:
    json.dump(remaining, outfile, indent=4)