In [None]:
import pandas as pd
import json
import re
import string
from datasets import Dataset
import torch
import numpy as np

# Location of the arXiv metadata snapshot (one JSON document per line).
DATA_PATH = "../../../arxiv/arxiv-metadata-oai-snapshot.json"
# Four-digit year starting with 19 or 20. The century alternation lives in a
# non-capturing group so that `re.findall` returns the whole year ("1995")
# instead of only the "19" prefix.
YEAR_PATTERN = r"((?:19|20)[0-9]{2})"

In [None]:
import transformers

In [None]:
# Bare expression: it only shows the transformers.logging module as the cell output
transformers.logging

## Load data

In [None]:
def clean_description(description: str) -> str:
    """Normalize free text into lowercase, space-separated ASCII tokens.

    Steps, in order: drop non-ASCII characters, replace punctuation with
    spaces, collapse whitespace runs, put a space in front of every
    uppercase ASCII letter, collapse whitespace runs again, and lowercase.

    The uppercase split also breaks acronyms apart ("BERT" becomes
    "b e r t"), and text that starts with an uppercase letter comes back
    with a leading space.

    Args:
        description: Raw text. ``None`` or an empty string is accepted.

    Returns:
        The cleaned text, or "" when ``description`` is falsy.
    """
    if not description:
        return ""
    # remove non-ASCII characters
    description = description.encode('ascii', 'ignore').decode()

    # replace punctuation with spaces
    description = re.sub('[%s]' % re.escape(string.punctuation), ' ', description)

    # clean up the spacing (raw string: '\s' in a plain literal is an invalid escape)
    description = re.sub(r'\s{2,}', " ", description)

    # remove any single newlines that are left
    description = description.replace("\n", " ")

    # split in front of every capital letter
    description = " ".join(re.split(r'(?=[A-Z])', description))

    # clean up the spacing again
    description = re.sub(r'\s{2,}', " ", description)

    # make all words lowercase
    return description.lower()

# Generator functions that iterate through the file and process/load papers
def process(paper: str, min_year: int = 1991, max_year: int = 2022) -> dict:
    """Parse one JSON line of the metadata snapshot into a flat record.

    Args:
        paper: A JSON document (as text) describing a single paper.
        min_year: Smallest year accepted from the journal reference.
        max_year: Largest year accepted from the journal reference.

    Returns:
        A dict with the keys ``id``, ``title``, ``year``, ``authors``,
        ``categories`` and ``abstract``. ``year`` is the smallest year
        matched by YEAR_PATTERN in ``journal-ref`` that lies within
        ``[min_year, max_year]``, or ``None`` when there is no journal
        reference or no year in range. ``categories`` has its spaces
        replaced by commas.
    """
    record = json.loads(paper)
    year = None
    journal_ref = record['journal-ref']
    if journal_ref:
        # Attempt to parse the date using Regex: this could be improved
        candidates = [int(match) for match in re.findall(YEAR_PATTERN, journal_ref)]
        in_range = [candidate for candidate in candidates if min_year <= candidate <= max_year]
        if in_range:
            # Keep the earliest plausible year mentioned in the reference
            year = min(in_range)
    return {
        'id': record['id'],
        'title': record['title'],
        'year': year,
        'authors': record['authors'],
        'categories': ','.join(record['categories'].split(' ')),
        'abstract': record['abstract'],
    }

def papers():
    """Yield the processed record of every snapshot line that has a usable year.

    Reads DATA_PATH line by line, passes each non-blank line to ``process``
    and yields the resulting dict when its ``year`` entry is truthy.
    """
    # JSON text is UTF-8 by specification, so don't depend on the platform default encoding
    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        for line in f:
            # A blank line (e.g. a trailing newline) is skipped instead of crashing the JSON parser
            if not line.strip():
                continue
            paper = process(line)
            # Yield only papers that have a year I could process
            if paper['year']:
                yield paper


In [None]:
# Collect every record yielded by papers() into a single DataFrame
df = pd.DataFrame(papers())

In [None]:
# Working copy that the following cells reduce to the text/label training frame
df_sample = df.copy()
# Second copy of the full frame; it is not referenced again in the visible cells
df_sample_clean = df.copy()

## Prepare text and labels

In [None]:
# Model input: title and abstract joined by a space, then normalized by clean_description
df_sample['text'] = (df_sample['title'] + ' ' + df_sample['abstract']).apply(clean_description)
# Only the cleaned text and the raw category string are needed from here on
df_sample = df_sample[['text', 'categories']]


In [None]:
# Scratch copy of the full frame, used by the next cell
df_tmp = df.copy()

In [None]:
# Preview of the one-hot category encoding; the result is displayed but not stored
df_tmp['categories'].str.get_dummies(sep=',')

In [None]:
# One indicator column per category, built from the comma-separated category string.
# The column order of ooe_df is what later maps predicted indices back to category names.
ooe_df = df_sample['categories'].str.get_dummies(sep=',')
num_classes = len(ooe_df.columns)

In [None]:
category_cols = ooe_df.columns.tolist()

def parse_labels(x):
    """Return one row's indicator values as a list ordered like category_cols."""
    return [x[c] for c in category_cols]

# parse the labels
# category_cols is exactly ooe_df's column order, so every row of the underlying
# array already is that paper's label vector. Converting the array in one step
# gives the same lists as applying parse_labels row by row, without a Python
# call per paper.
df_sample['labels'] = pd.Series(ooe_df.to_numpy().tolist(), index=ooe_df.index)
df_sample = df_sample[['text', 'labels']]

In [None]:
# Wrap the text/labels frame in a `datasets.Dataset` for tokenization and training
df_dataset = Dataset.from_pandas(df_sample)


## Modelling

In [None]:
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import Trainer, TrainingArguments

In [None]:
# Tokenizer for the same checkpoint that the classifier is loaded from.
# `problem_type` is a model configuration option, not a tokenizer option,
# so it is only given to the model.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny",
                                          model_max_length=512)

def tokenize_and_encode(examples):
    """Tokenize the "text" field of a batch, truncating to the tokenizer's maximum length."""
    return tokenizer(examples["text"], truncation=True)

# Tokenize in batches and drop every pre-existing column except the labels
cols = df_dataset.column_names
cols.remove('labels')
df_dataset = df_dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)

# Return torch tensors, then replace the labels with a float copy under the same
# column name and drop token_type_ids.
# NOTE(review): presumably floats are needed by the multi-label loss — confirm.
df_dataset.set_format("torch")
df_dataset = (
    df_dataset
    .map(lambda x: {"float_labels": x["labels"].to(torch.float)}, remove_columns=["labels", "token_type_ids"])
    .rename_column("float_labels", "labels")
)

In [None]:
# Sequence classifier with one output per category column, configured for multi-label targets
model = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-tiny",
    problem_type="multi_label_classification",
    num_labels=num_classes,
)

In [None]:
# A single epoch, saving once per epoch into `.outputs` and logging every 10000 steps
args = TrainingArguments(
    output_dir='.outputs',
    num_train_epochs=1,
    save_strategy="epoch",
    logging_steps=10000,
)

# The trainer gets the model, the training arguments, the tokenized dataset and the tokenizer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=df_dataset,
    tokenizer=tokenizer,
)

In [None]:
# Run the training loop configured on `trainer`
trainer.train()

In [None]:
# Predict on the same dataset that was used for training (no held-out split in the visible cells)
preds = trainer.predict(df_dataset)

## Get top category name from predictions

In [None]:
# Get top 3 predictions per paper
# `dim=1` normalizes across the classes of each paper; relying on softmax's
# implicit dimension is deprecated in PyTorch.
# NOTE(review): the model is configured as multi-label, where a sigmoid would give
# independent per-class scores. Softmax is kept so the existing confidences and the
# confidence threshold applied to them keep their meaning — confirm this is intended.
top_k_preds = torch.topk(
    torch.nn.functional.softmax(torch.tensor(preds.predictions), dim=1),
    k=3,
    dim=1
)
top_k_preds_confidence = top_k_preds.values
top_k_preds_idx = top_k_preds.indices

In [None]:
# Translate each paper's top-k class indices into category names via ooe_df's column order
labels = ooe_df.columns
predictions = [
    [labels[pred_idx] for pred_idx in article_pred]
    for article_pred in top_k_preds_idx.tolist()
]

In [None]:
from typing import List


# Minimum confidence a predicted category needs in order to be written to the output string
CONFIDENCE_THRESHOLD = 0.1

def build_prediction_string(prediction_labels: List[str], prediction_confidences: List[float], min_confidence: float) -> str:
    """Format the sufficiently confident predictions as "label(conf)|label(conf)".

    Labels and confidences are paired positionally. Pairs whose confidence is
    below ``min_confidence`` are left out, and confidences are printed with
    four decimals.

    Returns:
        The "|"-joined entries, or "" when nothing passes the threshold.
    """
    kept = [
        f"{label}({conf:.4f})"
        for label, conf in zip(prediction_labels, prediction_confidences)
        if not conf < min_confidence
    ]
    # Joining an empty list already yields the empty string
    return "|".join(kept)
    
# One prediction string per paper, in the same order as `predictions`
output = [
    build_prediction_string(pred_labels, pred_confidences, CONFIDENCE_THRESHOLD)
    for pred_labels, pred_confidences in zip(predictions, top_k_preds_confidence)
]

In [None]:
# Attach the prediction strings to the full frame (modifies df in place).
# Assumes `output` holds one entry per row of df, in df's row order — TODO confirm.
df["enriched_categories"] = output

In [None]:
# Save the enriched frame as a pickle in the parent of the current working directory
df.to_pickle("../papers_df.pkl")