# Fine-tune a ULMFiT + SentencePiece language model on arXiv abstracts

This notebook fine-tunes a ULMFiT language model that uses a SentencePiece unigram tokenization model. Both the tokenizer and the language model were trained on a corpus of 64K+ machine learning papers. Here we fine-tune on arXiv data using only titles and abstracts. Papers published before 2020 form the training set and later papers form the validation set; the arXiv test set is excluded from both.

In [1]:
# Run from the repository root so the relative paths used below (notebooks/..., models/...) resolve.
%cd ~/paperswithcode/paper-extractor

/home/ubuntu/paperswithcode/paper-extractor


In [2]:
import pandas as pd, numpy as np
from pathlib import Path

# Shared directory for the arxiv classifier data, relative to the repository root.
DATA_PATH = Path("notebooks/shared-notebooks/arxiv-class")
# NOTE(review): these two raw-input paths are not referenced by the cells shown here;
# the train/valid frames are read from pre-built pickles instead.
TRAIN_PATH = DATA_PATH.joinpath("arxiv-tag-classifier-data.json")
TEST_PATH = DATA_PATH.joinpath("classifier.tsv")

In [3]:
# Load the pre-split frames (train first, then valid); compression is inferred from ".gz".
# Unpickling can execute arbitrary code, so only load these files from a trusted source.
train_df, valid_df = (
    pd.read_pickle(DATA_PATH / "train_df.pkl.gz"),
    pd.read_pickle(DATA_PATH / "valid_df.pkl.gz"),
)

In [7]:
from fastai.text import *

# Artifacts of the pretrained ULMFiT baseline.
BASE_DIR = Path("./models/ulmfit_baseline")
VOCAB_PATH = BASE_DIR.joinpath("data_lm_export_vocab.pkl")
# Root path handed to the TextLists / data bunch built in the following cells.
MODELS_PATH = DATA_PATH.joinpath("models")

# Point the processor at the existing sentencepiece model and vocab files under BASE_DIR/tmp;
# mark_fields=True asks it to mark the separate text columns (title, abstract).
processor = SPProcessor(
    sp_model=BASE_DIR.joinpath("tmp", "spm.model"),
    sp_vocab=BASE_DIR.joinpath("tmp", "spm.vocab"),
    n_cpus=8,
    mark_fields=True,
)
# NOTE(review): `vocab` is not passed to anything in the cells shown here — confirm it is
# still needed, or use it to check that it matches the processor's vocabulary.
vocab = Vocab.load(VOCAB_PATH)

In [8]:
# Only the title and abstract are used as LM text.
text_cols = ["title", "abstract"]
# Validation TextList: same columns and the same sentencepiece processor as training.
valid_tl = TextList.from_df(
    valid_df,
    MODELS_PATH,
    cols=text_cols,
    processor=processor,
)

In [9]:
# Training TextList, built exactly like the validation one.
train_tl = TextList.from_df(
    train_df,
    MODELS_PATH,
    cols=text_cols,
    processor=processor,
)

In [10]:
# Pair train/valid lists, label them for language modelling and batch them (batch size 256).
data_lm = (
    ItemLists(MODELS_PATH, train_tl, valid_tl)
    .label_for_lm()
    .databunch(bs=256)
)

In [11]:
# Persist the processed LM data bunch so later sessions can reload it instead of re-processing the text.
# NOTE(review): the file is written relative to the bunch's path (MODELS_PATH here) — confirm location.
# The warning printed below comes from this call.
data_lm.save('data_lm_abs.pkl')

  "type " + obj.__name__ + ". It won't be checked "


In [12]:
# Stage 1: build the AWD-LSTM language-model learner and train only the unfrozen head for one cycle.
import random

import torch

# Seed every RNG involved in weight init and batch shuffling so the run can be reproduced
# (cuDNN kernels may still introduce minor nondeterminism on GPU).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

LM_DROP_MULT = 0.5  # multiplier applied to all AWD-LSTM dropout probabilities
FROZEN_EPOCHS = 1
FROZEN_LR = 1e-2

# NOTE(review): with pretrained_fnames=None, fastai v1 loads its default pretrained AWD_LSTM
# weights rather than the ML-papers LM described in the intro. To start from that model,
# set this to (weights_name, itos_name) — names without the .pth/.pkl extensions, looked up in
# learn.path / learn.model_dir. TODO confirm the filenames of the pretrained arxiv LM.
PRETRAINED_FNAMES = None

learn = language_model_learner(
    data_lm,
    AWD_LSTM,
    drop_mult=LM_DROP_MULT,
    pretrained_fnames=PRETRAINED_FNAMES,
)
learn.fit_one_cycle(FROZEN_EPOCHS, FROZEN_LR)

epoch,train_loss,valid_loss,accuracy,time
0,3.729671,3.633713,0.357466,1:37:35


In [13]:
# Stage 2: unfreeze all layer groups and fine-tune the whole LM for one cycle with a 1e-3 peak LR.
learn.unfreeze()
learn.fit_one_cycle(cyc_len=1, max_lr=1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,3.063642,3.065155,0.424697,1:56:45


In [14]:
# Save only the encoder weights, presumably for a downstream classifier to load via load_encoder.
# NOTE(review): save_encoder appends ".pth", so the file on disk is
# "arxiv_enc_sp30k_1_1_abstracts.pkl.pth"; the ".pkl" is misleading but kept because
# consumers may load the encoder by this exact name.
learn.save_encoder(name="arxiv_enc_sp30k_1_1_abstracts.pkl")