# Intuition on AAPD

This notebook contains the exploratory data analysis of the provided Arxiv dataset. 

- Q1: is the dataset the same as the one hosted in HF? A: yes, use this for loading the dataset
- Q2: what knowledge can we gain on the *categories* for subsequent use in training performant models? A: hierarchical labelset; not human-readable labels; exploit category distribution to make balanced validation split
  - Q3: what is the distribution of the number of papers per category? A: plot1.1
  - Q4: what is the distribution of the number of categories per paper? A: Label cardinality: 1.6957565665834098
  - Q5: is there any label noise or inconsistency? A: hard to know, could use a package such as [cleanlab](https://github.com/cleanlab/cleanlab) to estimate label noise
    - Q6: how many papers are in more than one category? A: by checking title if they are unique? 4020 papers with non-unique titles
  - Q7: is there a long-tail of primary-secondary categories? A: plot1.3 
- Q7: 
  - Q8: what is the distribution of the number of words per *abstract*? A: plot2.1
  - Q9: are all abstracts/papers in English? A: predominantly English, used [langdetect](https://pypi.org/project/langdetect/) to detect the language of the abstracts
Q10: Given primary categories/fields is the data somewhat separable? A: plot 3.1: PCA of sentence embeddings per unique category

#### nice documentation 
https://www.kaggle.com/code/matthewmaddock/nlp-arxiv-dataset-transformers-and-umap 


In [1]:
%load_ext autoreload
%autoreload 2
## necessary installs for EDA
!pip3 install numpy
!pip3 install scipy
!pip3 install beautifulsoup4
!pip3 install scikit-learn
!pip3 install matplotlib
!pip3 install pandas
!pip3 install plotly
!pip3 install datasets #hf datasets contains the dataset loader
!pip3 install sentence-transformers #for sentence embeddings EDA on abstracts

You should consider upgrading via the '/home/jordy/.virtualenvs/SOTA/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/jordy/.virtualenvs/SOTA/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/jordy/.virtualenvs/SOTA/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/jordy/.virtualenvs/SOTA/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/jordy/.virtualenvs/SOTA/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/jordy/.virtualenvs/SOTA/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/jordy/.virtualenvs/SOTA/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/jordy/.virtualenvs/SOTA/bin/python -m pip install --upgrade pip' command.[0m
You should consi

In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from scipy.stats import describe
from sklearn.model_selection import train_test_split
import plotly.graph_objects as go

In [2]:
from datasets import load_dataset
data = load_dataset("arxiv_dataset", data_dir='/home/jordy/code/opensource/IRIS/AAPD/data', ignore_verifications=True) #has to be a full path due to sloppy coding in hf datasets

## error: Nonmatchingsplitsize (patch HF dataset?) --> ignore_verifications
#NonMatchingSplitsSizesError: [{'expected': SplitInfo(name='train', num_bytes=3056873071, num_examples=2349354, shard_lengths=None, dataset_name=None), 'recorded': SplitInfo(name='train', num_bytes=3132619006, num_examples=2399802, shard_lengths=[415000, 385000, 367000, 350000, 343000, 401000, 138802], dataset_name='arxiv_dataset')}]
data



DatasetDict({
    train: Dataset({
        features: ['id', 'submitter', 'authors', 'title', 'comments', 'journal-ref', 'doi', 'report-no', 'categories', 'license', 'abstract', 'update_date'],
        num_rows: 2399802
    })
})

In [4]:
data['train']['categories'] #is organized according to primary and secondary categories; each category seems to be a hierarchical label (e.g. cs.AI is a subcategory of cs)
# data['train']['abstract'] #abstracts are not tokenized

['hep-ph',
 'math.CO cs.CG',
 'physics.gen-ph',
 'math.CO',
 'math.CA math.FA',
 'cond-mat.mes-hall',
 'gr-qc',
 'cond-mat.mtrl-sci',
 'astro-ph',
 'math.CO',
 'math.NT math.AG',
 'math.NT',
 'math.NT',
 'math.CA math.AT',
 'hep-th',
 'hep-ph',
 'astro-ph',
 'hep-th',
 'math.PR math.AG',
 'hep-ex',
 'nlin.PS physics.chem-ph q-bio.MN',
 'math.NA',
 'astro-ph',
 'nlin.PS',
 'cond-mat.str-el cond-mat.stat-mech',
 'math.RA',
 'cond-mat.mes-hall',
 'math.CA math.PR',
 'hep-ph',
 'cond-mat.str-el',
 'hep-ph',
 'hep-ph',
 'physics.optics physics.comp-ph',
 'q-bio.PE q-bio.CB quant-ph',
 'physics.optics physics.comp-ph',
 'q-bio.QM q-bio.MN',
 'physics.optics physics.comp-ph',
 'physics.optics physics.comp-ph',
 'hep-ph hep-lat nucl-th',
 'math.OA math.FA',
 'math.QA math-ph math.MP',
 'physics.gen-ph quant-ph',
 'cond-mat.stat-mech cond-mat.mtrl-sci',
 'astro-ph nlin.CD physics.plasm-ph physics.space-ph',
 'nlin.PS nlin.SI',
 'quant-ph cs.IT math.IT',
 'cs.NE cs.AI',
 'gr-qc astro-ph',
 'math

In [6]:
## let's look at the distribution of the categories and maybe look for a more human readable version of the categories: https://arxiv.org/category_taxonomy
categories = data['train']['categories']
# split by space and flatten
categories_listed = [cat.split(' ') for cat in categories]
categories_primary_listed = [cat[0] for cat in categories_listed]
categories_secondary_listed = [cat[1] for cat in categories_listed if len(cat) > 1]
categories_secondary_anylisted_na = [cat[1] if len(cat) > 1 else '' for cat in categories_listed]
unique_categories = sorted(set([item for sublist in categories_listed for item in sublist]))

cardinalities = [len(cat) for cat in categories_listed]
print(f"Label cardinality: {describe(cardinalities)}") #1.5 labels per abstract
unique_categories, len(unique_categories) # 176 unique categories

#plotly bar chart
fig = go.Figure(data=[go.Bar(x=unique_categories, y=[categories_primary_listed.count(cat) for cat in unique_categories])])
fig.update_layout(title_text='(1.1) Distribution of primary categories')
fig.show()
fig = go.Figure(data=[go.Bar(x=unique_categories, y=[categories_secondary_listed.count(cat) for cat in unique_categories])])
fig.update_layout(title_text='(1.2) Distribution of secondary categories')
fig.show()

Label cardinality: DescribeResult(nobs=2399802, minmax=(1, 13), mean=1.6957565665834098, variance=0.8557003549128224, skewness=1.4899745134440587, kurtosis=2.450324154023784)


In [5]:
# distribution of primary and secondary categories
fig = go.Figure(data=[go.Bar(x=unique_categories, y=[categories_primary_listed.count(cat) for cat in unique_categories], name='primary'),
                      go.Bar(x=unique_categories, y=[categories_secondary_listed.count(cat) for cat in unique_categories], name='secondary')])
fig.update_layout(title_text='(1.3) Distribution of primary and secondary categories')
fig.show()

In [6]:
import numpy as np

unique_categories_na = unique_categories + ['']
# Create a numpy array to store the correlation values
categories_corr = np.zeros((len(unique_categories_na), len(unique_categories_na)))

# Iterate over the categories_listed and update the correlation matrix
for i in tqdm(range(len(categories_listed))):
    primary = categories_primary_listed[i]
    primary_index = unique_categories_na.index(primary)
    secondary = categories_secondary_anylisted_na[i]
    secondary_index = unique_categories_na.index(secondary)
    categories_corr[primary_index, secondary_index] += 1

# Calculate the correlation matrix
categories_corr = np.corrcoef(categories_corr.T)
categories_corr
fig = go.Figure(data=go.Heatmap(z=categories_corr, x=unique_categories, y=unique_categories))
fig.update_layout(title_text='(1.4) Correlation between primary and secondary categories')
fig.show()

  0%|          | 0/2399802 [00:00<?, ?it/s]

100%|██████████| 2399802/2399802 [00:05<00:00, 407099.54it/s]

invalid value encountered in true_divide


invalid value encountered in true_divide



In [7]:
fields= sorted(set([cat.split(".")[0] for cat in unique_categories]))
fields, len(fields)
#print([cat for cat in unique_categories if categories_primary_listed.count(cat) > 10000]) # 10 categories with more than 10k papers

(['acc-phys',
  'adap-org',
  'alg-geom',
  'ao-sci',
  'astro-ph',
  'atom-ph',
  'bayes-an',
  'chao-dyn',
  'chem-ph',
  'cmp-lg',
  'comp-gas',
  'cond-mat',
  'cs',
  'dg-ga',
  'econ',
  'eess',
  'funct-an',
  'gr-qc',
  'hep-ex',
  'hep-lat',
  'hep-ph',
  'hep-th',
  'math',
  'math-ph',
  'mtrl-th',
  'nlin',
  'nucl-ex',
  'nucl-th',
  'patt-sol',
  'physics',
  'plasm-ph',
  'q-alg',
  'q-bio',
  'q-fin',
  'quant-ph',
  'solv-int',
  'stat',
  'supr-con'],
 38)

In [34]:
underspecified_categories = sorted(set([cat for cat in unique_categories if len(cat.split(".")) == 1])) #underspecified category?
underspecified_categories, len(underspecified_categories), {k:categories_primary_listed.count(k) for k in underspecified_categories} #30 underspecified categories that have no subcategories

(['acc-phys',
  'adap-org',
  'alg-geom',
  'ao-sci',
  'astro-ph',
  'atom-ph',
  'bayes-an',
  'chao-dyn',
  'chem-ph',
  'cmp-lg',
  'comp-gas',
  'cond-mat',
  'dg-ga',
  'funct-an',
  'gr-qc',
  'hep-ex',
  'hep-lat',
  'hep-ph',
  'hep-th',
  'math-ph',
  'mtrl-th',
  'nucl-ex',
  'nucl-th',
  'patt-sol',
  'plasm-ph',
  'q-alg',
  'q-bio',
  'quant-ph',
  'solv-int',
  'supr-con'],
 30,
 {'acc-phys': 46,
  'adap-org': 306,
  'alg-geom': 1209,
  'ao-sci': 13,
  'astro-ph': 94246,
  'atom-ph': 68,
  'bayes-an': 11,
  'chao-dyn': 1770,
  'chem-ph': 129,
  'cmp-lg': 894,
  'comp-gas': 140,
  'cond-mat': 11357,
  'dg-ga': 562,
  'funct-an': 320,
  'gr-qc': 61387,
  'hep-ex': 22487,
  'hep-lat': 17667,
  'hep-ph': 129686,
  'hep-th': 103319,
  'math-ph': 30958,
  'mtrl-th': 165,
  'nucl-ex': 11369,
  'nucl-th': 32661,
  'patt-sol': 452,
  'plasm-ph': 28,
  'q-alg': 1177,
  'q-bio': 0,
  'quant-ph': 104101,
  'solv-int': 844,
  'supr-con': 69})

In [3]:
## reuse code from https://www.kaggle.com/code/lucafuligni/exploratory-data-analysis-arxiv-dataset#Categories to obtain a more human readable version of the categories
## alternate: https://www.kaggle.com/code/matthewmaddock/nlp-arxiv-dataset-transformers-and-umap 
import requests
from bs4 import BeautifulSoup
import pandas as pd

url = "https://arxiv.org/category_taxonomy"

# Send a GET request to the URL
response = requests.get(url)

# Get the page source from the response content
page_source = response.text

# Parse the page source
soup = BeautifulSoup(page_source, 'html.parser')

# Find the category list element
category_list = soup.find(id='category_taxonomy_list')

# Extract the category information
categories = []
main_category = None
for category in category_list.find_all('h4'):
    if category.find_previous('h2'):
        main_category = category.find_previous('h2').text
    category_id = category.text.split(' (')[0]  # Switched with 'category_name'
    category_name = category.text.split('(')[1].split(')')[0]  # Switched with 'category_id'
    category_description = category.find_next('p').text
    categories.append({
        "ID": category_id,  # Switched with 'category_name'
        "Main Category": main_category,
        "Name": category_name,  # Switched with 'category_id'
        "Description": category_description
    })


# Create a dataframe from the categories list
categories_df = pd.DataFrame(categories)
categories_df

Unnamed: 0,ID,Main Category,Name,Description
0,cs.AI,Computer Science,Artificial Intelligence,"Covers all areas of AI except Vision, Robotics..."
1,cs.AR,Computer Science,Hardware Architecture,Covers systems organization and hardware archi...
2,cs.CC,Computer Science,Computational Complexity,"Covers models of computation, complexity class..."
3,cs.CE,Computer Science,"Computational Engineering, Finance, and Science",Covers applications of computer science to the...
4,cs.CG,Computer Science,Computational Geometry,Roughly includes material in ACM Subject Class...
...,...,...,...,...
150,stat.CO,Statistics,Computation,"Algorithms, Simulation, Visualization"
151,stat.ME,Statistics,Methodology,"Design, Surveys, Model Selection, Multiple Tes..."
152,stat.ML,Statistics,Machine Learning,"Covers machine learning papers (supervised, un..."
153,stat.OT,Statistics,Other Statistics,Work in statistics that does not fit into the ...


In [7]:
## manual lookup of missing categories
category_mapping = {
    'acc-phys': 'Accelerator Physics',
  'adap-org': 'Adaptation, Noise, and Self-Organizing Systems',
  'alg-geom': 'Algebraic Geometry',
  'ao-sci': 'Atmospheric and Oceanic Physics',
  'astro-ph': 'Astrophysics',
  'atom-ph': 'Atomic Physics',
  'bayes-an': 'Bayesian Analysis',
  'chao-dyn': 'Chaotic Dynamics',
  'chem-ph': 'Chemical Physics',
  'cmp-lg'  : 'Computation and Language',
  'comp-gas'    : 'Cellular Automata and Lattice Gases',
  'cond-mat': 'Condensed Matter',
  'dg-ga': 'Differential Geometry',
  'funct-an': 'Functional Analysis',
  'mtrl-th': 'Materials Science',
  'patt-sol': 'Pattern Formation and Solitons',
  'plasm-ph': 'Plasma Physics',
  'q-alg': 'Quantum Algebra',
  'q-bio': 'Quantitative Biology',
  'solv-int': 'Exactly Solvable and Integrable Systems',
  'supr-con': 'Superconductivity',
}

for cat in category_mapping:
    categories_df = categories_df.append({'ID': cat, 'Main Category': '', 'Name': category_mapping[cat], 'Description': ''}, ignore_index=True)
categories_df.to_csv('/home/jordy/code/opensource/IRIS/AAPD/data/arxiv_categories.csv', index=False)

In [8]:
# flatten categoriies
categories_flat = [c for cat in categories_listed for c in cat]
categories_df['ID'].nunique(), categories_df['Name'].nunique(), len(unique_categories)

superfluous = [cat for cat in unique_categories if cat not in categories_df['ID'].values]
superfluous, len(superfluous) # 0 superfluous categories
for cat in superfluous:
    print(f"{cat}: 1:{categories_primary_listed.count(cat)} 2:{categories_secondary_listed.count(cat)}")
  ## - some are very infrequent, e.g. 'q-bio' (Quantitative Biology) with 0 primary Category counts

## lets look at the long tail of categories to see if we can remove some categories
long_tail_primary, long_tail_any = [], []
for cat in unique_categories:
    primary_count = categories_primary_listed.count(cat)
    any_count = categories_flat.count(cat)
    if primary_count < 10: #very secondary categories
        long_tail_primary.append(cat)
    if any_count < 50: #very infrequent categories
        long_tail_any.append(cat)
long_tail_primary, long_tail_any, len(long_tail_primary), len(long_tail_any) #
CATEGORIES_TO_REMOVE = set(long_tail_primary + long_tail_any)
CATEGORIES_TO_REMOVE

{'acc-phys',
 'ao-sci',
 'bayes-an',
 'math.IT',
 'math.MP',
 'plasm-ph',
 'q-bio',
 'stat.TH'}

### My understanding of the categories

Fields: (38)

Fields of study/Categories: (176) - at most 2 levels deep with '.', e.g., *acc-phys or CS.AI*

Categories can be chained up to 13 (observed) times, yet on average 1.7 categories per paper. 

Open questions:
- Is the average label cardinality different per field?
- How correlated is the order of categories?
- Is there any label imputation needed?
- Can we translate categories to human readable labels (potentially including the instruction/description) for use in zero-shot classification (e.g., setfit?)? 

Beware that there might be label noise in the dataset from category misuse or mislabeling. (https://blog.arxiv.org/2019/12/05/arxiv-machine-learning-classification-guide/)

From my own experience, I believe there is a lot of noise in the categories, and that the categories are not always used consistently. 
The newer flow of papers might be more consistent, as you need to choose a primary and secondary category only, but the older papers will most likely be more noisy.

In [32]:
#### let's make a subset with 50 papers per primary category for EDA purposes
K_per_category = 50
cat_indices = []
for cat in unique_categories:
    indices = []
    for i, x in enumerate(categories_listed): 
        if x[0] == cat:
            indices.append(i)
            if len(indices) == K_per_category:
                break
    cat_indices.extend(indices)

subset = data['train'].select(cat_indices)

In [15]:
# Q: is the language always English?  --> yes in 99.9% of the cases
from langdetect import detect
from langdetect import DetectorFactory
from collections import Counter
from tqdm import tqdm
DetectorFactory.seed = 0 #for reproducibility
languages = [detect(abstract) for abstract in tqdm(subset['abstract'])]
languages_set = Counter(languages)
languages_set #only English?

100%|██████████| 16835/16835 [01:12<00:00, 231.53it/s]


Counter({'en': 16821, 'de': 2, 'et': 1, 'fr': 8, 'it': 1, 'es': 1, 'da': 1})

In [19]:
## let's look at the length of the abstracts to see if they will fit in Transformer models
abstract_lengths = [len(abstract.split()) for abstract in subset['abstract']]
fig = go.Figure(data=[go.Histogram(x=abstract_lengths, nbinsx=100)])
fig.update_layout(title_text='(2.1) Distribution of abstract lengths')
fig.show()
print([abstract for abstract in subset['abstract'] if len(abstract.split()) < 14]) #short abstracts - withdrawn or plagiarized

In [28]:
print([abstract for abstract in subset['abstract'] if len(abstract.split()) < 14]) #short abstracts
#len(data['train']['title']), len(np.unique(data['train']['title'])) #2349354 --> duplicate titles

['  (Makes a Gamma-acylic coherent resolution of a coherent sheaf on a projection\nscheme.)\n', '  (Generalizes theorem of Atiyah and Mumford.)\n', '  We study limiting lines on degenerations of generic hypersurfaces in $P^n$.\n', '  This is a revised version of the paper submitted before.\n', '  This paper has been withdrawn by the authors.\n', '  A selfcontained proof of the KAM theorem in the Thirring model is discussed.\n', '  We comment on a recent article by Hao and Scheraga.\n', '  See chem-ph/9505003.\n', '  The paper was withdrawn by the authors\n', '  paper withdrawn due to the possible error in numerical eigenfunction\ncalculation\n', '  This paper visualizes a knot reduction algorithm\n', '  This paper has been withdrawn.\n', '  Description of a polynomial time reduction of SAT to 2-SAT of polynomial\nsize.\n', '  Major mistakes do not read\n', '  This paper has been withdrawn\n', '  This paper has been withdrawn\n', '  This paper has been withdrawn\n', '  We describe and p

In [33]:
### sentence encoder to embed every abstract, take averages and plot per primary category

from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-mpnet-base-v2')

## DEV: alternative make scibert into a sentence encoder
# from transformers import AutoTokenizer, AutoModel
# tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
# model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
# model = SentenceTransformer(model)

#embed every abstract
abstracts = subset['abstract']
abstract_embeddings = model.encode(abstracts, batch_size=32, show_progress_bar=True)
cats = subset['categories']
cats_primary = [cat.split(" ")[0] for cat in cats]
field_primary = [cat.split(".")[0] for cat in cats_primary]

# plot TSNE of the embeddings
from sklearn.manifold import TSNE
import plotly.express as px
tsne = TSNE(n_components=2, random_state=0)
abstracts_tsne = tsne.fit_transform(abstract_embeddings)
abstracts_tsne_df = pd.DataFrame(abstracts_tsne, columns=['tsne1', 'tsne2'])
abstracts_tsne_df['field_primary'] = field_primary
fig = px.scatter(abstracts_tsne_df, x="tsne1", y="tsne2", color="field_primary")
fig.update_layout(title_text='(2.2) TSNE plot of abstract embeddings')
fig.show()

Batches:   0%|          | 0/266 [00:00<?, ?it/s]

# Preprocessing for training

Now that the EDA is complete, we can define the functions necessary for preprocessing the data for training.

0. removing all outlier values (duplicates, long-tail noise, etc.)
1. keeping primary-secondary categories with at least a certain number of papers (long-tail categories are not very useful for training)
2. creating a balanced validation split
3. defining a labelset for the model, depending on the type of model we want to train
   (0. multi-class classification (primary-only using SciBERT))
   1. zero-shot classification (using the setfit model) --> string labels - human readable/instruction
   2. multi-label classification (using SciBERT) --> list labels
   3. generative classification (FlanT5) --> string labels - human readable/instruction

In [9]:
# conditions: duplicate titles, long-tailed distribution of categories, underspecified categories, language, abstract length

to_list = lambda cat: cat.split(" ") #map the categories to the field
to_field = lambda cats: cats[0].split(".")[0] #map the categories to the field (first)
to_primary = lambda cats: cats[0] #map the categories to the primary categories
to_secondary = lambda cats: cats[1] if len(cats) > 1 else '' #map the categories to the secondary categories (could be used for hacking multi-class classification with the average label cardinality of 1.7); then ensemble

cat_df = pd.read_csv('../data/arxiv_categories.csv', sep=',') 

def stringlabel_mapping(cats, column='Name'):  #map the categories to the string labels
    joiner = []
    for cat in cats:
        if cat in cat_df['ID']:
            joiner.append(cat_df[cat_df['ID'] == cat][column].values[0])
    return ' ; '.join(joiner)    

remove_indices = set()
seen = set()
for i, el in enumerate(data['train']):
    if el['title'] in seen:
        remove_indices.add(i)
    else:
        seen.add(el['title'])
    simple_tokenized = el['abstract'].split()
    if len(simple_tokenized) < 14 or len(simple_tokenized) > 350: 
        remove_indices.add(i)
    if any(cat in CATEGORIES_TO_REMOVE for cat in el['categories'].split()):
        remove_indices.add(i)

data['train'] = data['train'].select([i for i in range(len(data['train'])) if i not in remove_indices])

#complex function to apply to the dataset 
def process_categories(x):
    #x['field_primary'] = to_field(x['categories'])
    x['cats'] = to_list(x['categories']) #list of categories
    x['primary'] = to_primary(x['cats'])
    x['secondary'] = to_secondary(x['cats'])
    x['strlabel'] = stringlabel_mapping(x['cats'])
    return x

#keep only relevant columns (id, labels, abstract)
data['train'] = data['train'].map(process_categories, batch_size=32)

Map:   0%|          | 0/2247644 [00:00<?, ? examples/s]

In [10]:
keep_columns = ['id', 'cats', 'primary', 'secondary', 'strlabel', 'abstract']
remove_columns=[col for col in data['train'].column_names if col not in keep_columns]
data['train'] = data['train'].remove_columns(remove_columns)
remove_columns, data['train']

(['submitter',
  'authors',
  'title',
  'comments',
  'journal-ref',
  'doi',
  'report-no',
  'categories',
  'license',
  'update_date'],
 Dataset({
     features: ['id', 'abstract', 'cats', 'primary', 'secondary', 'strlabel'],
     num_rows: 2247644
 }))

In [14]:
## due to the absence of a validation split, let's make our own split, keeping the label distribution as similar as possible
### the below function could be used for more advanced stratification, to also keep the language or length of the abstract balanced over splits
### keeping it simple for now, as we are only interested in the label distribution for the PoC

'''
from skmultilearn.model_selection import IterativeStratification


def iterative_split(df, stratify_columns, splits=(0.25, 0.75)):
    """Custom iterative train test split which
    'maintains balanced representation with respect
    to order-th label combinations.'

    From http://lpis.csd.auth.gr/publications/sechidis-ecmlpkdd-2011.pdf (page 5)
    """
    # One-hot encode the stratify columns and concatenate them
    one_hot_cols = [pd.get_dummies(df[col]) for col in stratify_columns]
    one_hot_cols = pd.concat(one_hot_cols, axis=1).to_numpy()
    stratifier = IterativeStratification(
        n_splits=2, order=len(stratify_columns), sample_distribution_per_fold=list(splits))
    train_indices, test_indices = next(stratifier.split(df.to_numpy(), one_hot_cols))
    # Return the train and test set dataframes
    train, test = df.iloc[train_indices], df.iloc[test_indices]
    return train, test

train, test = iterative_split(split_df, stratify_columns=["primary","secondary"], splits=(0.25, 0.75))
'''
#train, validation = train_test_split(data['train'], test_size=0.1, random_state=42, stratify=data['train']['strlabel']) #10% as we will train stepwise and might not use the whole trainset
#dataset = DatasetDict({'train':  Dataset.from_dict(train), 'validation': Dataset.from_dict(validation)})

#make new HF dataset wit both splits
from datasets import DatasetDict, Dataset

stratify_column_name = 'strlabel'

dataset = data['train'].class_encode_column(
    stratify_column_name
).train_test_split(
    test_size=0.1, 
    stratify_by_column=stratify_column_name
)

holdout = dataset['test'].train_test_split(
    test_size=0.3, 
    stratify_by_column=stratify_column_name
)
dataset = DatasetDict({'train':  dataset['train'], 'validation': holdout['test'], 'test': holdout['train']})
# save to disk

dataset.save_to_disk('/home/jordy/code/opensource/IRIS/AAPD/data/arxiv_dataset_prepped') #save to disk
#dataset.push_to_hub('arxiv_dataset_prep') #push to HF hub

Saving the dataset (0/5 shards):   0%|          | 0/2022879 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/67430 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/157335 [00:00<?, ? examples/s]