In [23]:
from datasets import load_dataset
import re
import pprint

In [None]:
# Both datasets come from the same Hugging Face repository: the paper rows
# ("default" config, train split) and the table describing each category tag.
ARXIV_REPO = "TimSchopf/arxiv_categories"

ds_raw_train = load_dataset(
    ARXIV_REPO,
    "default",
    split="train",
)
ds_cats = load_dataset(
    ARXIV_REPO,
    "arxiv_category_descriptions",
    split="arxiv_category_descriptions",
)

In [26]:
# Peek at the first two rows of the paper dataset, then of the category table.
for preview_source in (ds_raw_train, ds_cats):
    pprint.pprint(preview_source[:2])

{'abstract': ['In order to read meter values from a camera on an autonomous '
              'inspection robot with positional errors, it is necessary to '
              'detect meter regions from the image. In this study, we '
              'developed shape-based, texture-based, and background '
              'information-based methods as meter area detection techniques '
              'and compared their effectiveness for meters of different shapes '
              'and sizes. As a result, we confirmed that the background '
              'information-based method can detect the farthest meters '
              'regardless of the shape and number of meters, and can stably '
              'detect meters with a diameter of 40px.',
              'We consider the problem of finding the transition rates of a '
              'continuous-time homogeneous Markov chain under the empirical '
              'condition that the state changes at most once during a time '
              'interval of uni

## Structure of data
The data comes as two different Dataset objects:
1. "default", which contains the following features relevant to our problem:
    - title, abstract, categories

2. "arxiv_category_descriptions", which contains features that only describe each category:
    - tag, name

Therefore, we will need the 'tag' feature of the second dataset to identify all the categories
available in this dataset, and then map those tags onto the 'categories' feature of the first
dataset to identify the category (or categories) of each datapoint.

**NOTE** - there may be multiple "tags" within the "categories" field for a given paper,
since a paper can belong to more than one category. For now, it is probably best to label
each paper based on the first "tag" encountered in "categories".

In [None]:
def get_all_categories(category_ds):
    """Return every category tag listed in the category-description dataset.

    Parameters
    ----------
    category_ds : mapping-like
        Anything indexable by ``'tag'`` whose value is an iterable of tag
        strings (here, the ``arxiv_category_descriptions`` split).

    Returns
    -------
    list
        A new list holding the tags in their original order.
    """
    return list(category_ds['tag'])

# Gather every available category tag and inspect them.
all_categories = get_all_categories(category_ds=ds_cats)
pprint.pprint(all_categories)

['cs.AI',
 'cs.AR',
 'cs.CC',
 'cs.CE',
 'cs.CG',
 'cs.CL',
 'cs.CR',
 'cs.CV',
 'cs.CY',
 'cs.DB',
 'cs.DC',
 'cs.DL',
 'cs.DM',
 'cs.DS',
 'cs.ET',
 'cs.FL',
 'cs.GL',
 'cs.GR',
 'cs.GT',
 'cs.HC',
 'cs.IR',
 'cs.IT',
 'cs.LG',
 'cs.LO',
 'cs.MA',
 'cs.MM',
 'cs.MS',
 'cs.NA',
 'cs.NE',
 'cs.NI',
 'cs.OH',
 'cs.OS',
 'cs.PF',
 'cs.PL',
 'cs.RO',
 'cs.SC',
 'cs.SD',
 'cs.SE',
 'cs.SI',
 'cs.SY',
 'econ.EM',
 'econ.GN',
 'econ.TH',
 'eess.AS',
 'eess.IV',
 'eess.SP',
 'eess.SY',
 'math.AC',
 'math.AG',
 'math.AP',
 'math.AT',
 'math.CA',
 'math.CO',
 'math.CT',
 'math.CV',
 'math.DG',
 'math.DS',
 'math.FA',
 'math.GM',
 'math.GN',
 'math.GR',
 'math.GT',
 'math.HO',
 'math.IT',
 'math.KT',
 'math.LO',
 'math.MG',
 'math.MP',
 'math.NA',
 'math.NT',
 'math.OA',
 'math.OC',
 'math.PR',
 'math.QA',
 'math.RA',
 'math.RT',
 'math.SG',
 'math.SP',
 'math.ST',
 'astro-ph',
 'astro-ph.CO',
 'astro-ph.EP',
 'astro-ph.GA',
 'astro-ph.HE',
 'astro-ph.IM',
 'astro-ph.SR',
 'cond-mat',
 'cond-mat

In [63]:
# Category paths of the first 30 papers (some papers list several).
ds_raw_train[0:30]["categories"]

[['Computer Science Archive->cs.CV'],
 ['Mathematics Archive->math.PR'],
 ['Physics Archive->astro-ph->astro-ph.EP'],
 ['Computer Science Archive->cs.DB'],
 ['Physics Archive->nlin->nlin.CD'],
 ['Physics Archive->astro-ph->astro-ph.SR'],
 ['Physics Archive->physics->physics.atom-ph'],
 ['Computer Science Archive->cs.CV', 'Computer Science Archive->cs.LG'],
 ['Physics Archive->astro-ph->astro-ph.GA',
  'Physics Archive->astro-ph->astro-ph.SR'],
 ['Mathematics Archive->math.OC'],
 ['Physics Archive->astro-ph->astro-ph.GA',
  'Physics Archive->astro-ph->astro-ph.SR'],
 ['Mathematics Archive->math.PR'],
 ['Physics Archive->cond-mat->cond-mat.mtrl-sci'],
 ['Mathematics Archive->math.AP'],
 ['Physics Archive->gr-qc',
  'Physics Archive->hep->hep-th',
  'Physics Archive->quant-ph'],
 ['Physics Archive->gr-qc', 'Physics Archive->hep->hep-th'],
 ['Computer Science Archive->cs.CC'],
 ['Physics Archive->gr-qc'],
 ['Mathematics Archive->math.AT'],
 ['Physics Archive->astro-ph->astro-ph.GA',
  'Phy

In [67]:
# Row 7 is one of the papers with more than one category path.
ds_raw_train[7]["categories"]

['Computer Science Archive->cs.CV', 'Computer Science Archive->cs.LG']

## Pattern matching strategy

1. The 'categories' field in the dataset may contain one or more category paths. Therefore,
   we need to process each of them.

2. The part to match, i.e. the text that appears in the 'tag' list, always appears after a `->`
   and is followed either by another `->` or by the end of the string.

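3. Therefore, we will need a function that extracts the parts of each 'categories' string
   that match the pattern `->x->` or `->x`. This function will need to be applied to every
   category path of a given datapoint. For nested paths such as
   `Physics Archive->astro-ph->astro-ph.EP`, the most specific tag is the last segment
   (`astro-ph.EP`).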

4. The first `->` **always** appears immediately after the word "Archive".

In [None]:
def get_category_strings(example):
    """Extract the arXiv category tag from each category path of one example.

    Each entry of ``example['categories']`` is a hierarchy path such as
    ``'Computer Science Archive->cs.CV'`` or
    ``'Physics Archive->astro-ph->astro-ph.EP'``. The tag used in the
    category-description dataset's ``tag`` column is the most specific,
    i.e. last, ``->``-separated segment (``'cs.CV'``, ``'astro-ph.EP'``).

    Taking the last segment handles paths of any depth. The previous regex
    substitution returned ``'astro-ph->astro-ph.EP'`` for nested paths, which
    never matches a tag.

    Parameters
    ----------
    example : mapping
        A single dataset row; only the ``'categories'`` key is read. Its value
        is a list of hierarchy-path strings.

    Returns
    -------
    list of str
        One tag per path, in the same order as ``example['categories']``.
        Surrounding whitespace is stripped. A path without ``->`` is returned
        as-is, apart from the stripping.
    """
    return [path.rsplit('->', 1)[-1].strip() for path in example['categories']]


In [272]:
# Compare the extracted tags with the raw category paths for one row.
sample_row = ds_raw_train[54]
extracted_categories = get_category_strings(sample_row)
print(extracted_categories)
print(sample_row['categories'])

['quant-ph']
['Physics Archive->quant-ph']


In [252]:
# Scratch check of the category regex on a single example row.
# NOTE: for nested paths such as 'Physics Archive->astro-ph->astro-ph.EP' the
# substitution r'\2\3\4' yields 'astro-ph->astro-ph.EP' ([A-Za-z.-]+ absorbs the
# '-' of '->', group 3 captures '>'), not the leaf tag 'astro-ph.EP'.
pattern = re.compile(r'[\w\s]*Archive(->)([A-Za-z.-]+)(>|\w*)?([A-Za-z.-]*)')
idx = 17
category_path = ds_raw_train[idx]['categories'][0]

matches = pattern.finditer(category_path)
sub_matches = pattern.sub(r'\2\3\4', category_path)
for match in matches:
    print(match)
# The substitution result does not depend on the loop, so print it once.
print(sub_matches)

category_path


<re.Match object; span=(0, 22), match='Physics Archive->gr-qc'>
gr-qc


'Physics Archive->gr-qc'