In [1]:
import os
if not os.path.exists('ArXiv-10.zip'):
  os.system('wget https://github.com/ashfarhangi/Protoformer/raw/main/data/ArXiv-10.zip')
  os.system('unzip ArXiv-10.zip')

In [2]:
import argparse

import pandas as pd
import os
import csv
import numpy as np
import re
import string

mypath = ''

In [3]:
def preprocess(text):
    
    if text == '':
        return text
    
    # remove extra commas and semi-colons
    text = text.replace(".,",".")
    text = text.replace("..",".")
    text = text.replace(".;",".")
    
    text = text.replace("!,","!")
    text = text.replace("!.","!")
    text = text.replace("!;","!")
    
    text = text.replace("?,","!")
    text = text.replace("?.","!")
    text = text.replace("?;","!")
    
    text = text.replace(". ,",".")
    text = text.replace(". .",".")
    text = text.replace(". ;",".")
    
    text = text.replace("! ,","!")
    text = text.replace("! .","!")
    text = text.replace("! ;","!")
    
    text = text.replace("? ,","!")
    text = text.replace("? .","!")
    text = text.replace("? ;","!")
    
    # Other common issues
    
    # remove first character if it is a punctuation
    if text[0] in string.punctuation:
        text = text[1:]
    
    # remove extra commas in text
    text = re.sub(r'[.]+[\n]+[,]',".\n", text)
    
    # remove extra semi-colons in text
    text = re.sub(r'[.]+[\n]+[;]',".\n", text)
    
    # Replace new line with space
    text = text.replace("\n"," ")
    
    # Replace tab with space
    text = text.replace("\t"," ")
    
    # Remove random new line + comma
    text = text.replace("\n,","")
    
    # Replace multiple spaces with a single space
    text = re.sub(' +', ' ', text)
    
    # Remove trailing characters if it does not end with .
    while text[-1] != '.' and len(text)>2:
        text = text[:-2]
        
    # Remove initial characters if they are space or puctuation
    while text[0] in string.punctuation or text[0] == ' ':
        text = text[1:]
        if text == '':
            return text
    
    return text

In [4]:
arxiv10 = pd.read_csv(mypath + 'arxiv100.csv')
arxiv10['abstract'] = arxiv10['abstract'].map(lambda x:preprocess(x))

In [5]:
arxiv10.head()

Unnamed: 0,title,abstract,label
0,The Pre-He White Dwarfs in Eclipsing Binaries....,We report the first $BV$ light curves and high...,astro-ph
1,A Possible Origin of kHZ QPOs in Low-Mass X-ra...,A possible origin of kHz QPOs in low-mass X-ra...,astro-ph
2,The effects of driving time scales on heating ...,Context. The relative importance of AC and DC ...,astro-ph
3,A new hard X-ray selected sample of extreme hi...,Extreme high-energy peaked BL Lac objects (EHB...,astro-ph
4,The baryon cycle of Seven Dwarfs with superbub...,"We present results from a high-resolution, cos...",astro-ph


In [6]:
arxiv10 = arxiv10.dropna()

In [7]:
arxiv10.dropna().shape

(100000, 3)

In [8]:
dataset = arxiv10
dataset = dataset.dropna()
dataset.to_csv(mypath +'dataset.csv',index=False)
dataset = pd.read_csv(mypath + 'dataset.csv')


while dataset.shape!=dataset.dropna().shape:
  dataset = dataset.dropna()
  dataset.to_csv(mypath +'dataset.csv',index=False)
  dataset = pd.read_csv(mypath + 'dataset.csv')

In [9]:
train, val, test = \
              np.split(dataset.sample(frac=1, random_state=2023), 
                       [int(.7*len(dataset)), int(.85*len(dataset))])

  return bound(*args, **kwds)


In [10]:
train.to_csv(mypath +'train.csv',index=False)
test.to_csv(mypath +'test.csv',index=False)
val.to_csv(mypath +'val.csv',index=False)

In [11]:
print(len(train))
print(len(val))
print(len(test))

67548
14475
14475


In [12]:
print(len(train.dropna()))
print(len(val.dropna()))
print(len(test.dropna()))

67548
14475
14475


Train/Test/Validation sets are quite balanced.

In [13]:
train.groupby(['label']).size()

label
astro-ph    6884
cond-mat    6849
cs          6924
eess        6884
hep-ph      6629
hep-th      6699
math        6348
physics     6720
quant-ph    6746
stat        6865
dtype: int64

In [14]:
test.groupby(['label']).size()

label
astro-ph    1529
cond-mat    1505
cs          1469
eess        1435
hep-ph      1428
hep-th      1391
math        1355
physics     1483
quant-ph    1477
stat        1403
dtype: int64

In [15]:
val.groupby(['label']).size()

label
astro-ph    1456
cond-mat    1414
cs          1431
eess        1507
hep-ph      1453
hep-th      1414
math        1305
physics     1525
quant-ph    1431
stat        1539
dtype: int64

In [16]:
train = pd.read_csv(mypath + 'train.csv')
test = pd.read_csv(mypath + 'test.csv')
val = pd.read_csv(mypath + 'val.csv')

In [17]:
import json
id2label = ['physics', 'hep-ph', 'eess', 'astro-ph', 'hep-th', 'quant-ph', 'stat', 'math', 'cond-mat', 'cs']
label2id = {'physics': 0, 'hep-ph': 1, 'eess': 2, 'astro-ph': 3, 'hep-th': 4, 'quant-ph': 5, 'stat': 6, 'math': 7, 'cond-mat': 8, 'cs': 9}

with open("id2label.json", "w") as f:
    json.dump(id2label, f)
    
with open("label2id.json", "w") as f:
    json.dump(label2id, f)

In [18]:
train['label'] = [label2id[key] for key in train['label']]
test['label'] = [label2id[key] for key in test['label']]
val['label'] = [label2id[key] for key in val['label']]

train = train.drop(['title'], axis = 1)
test = test.drop(['title'], axis = 1)
val = val.drop(['title'], axis = 1)

In [19]:
train.to_csv('train_abstract_ilabel.csv')
val.to_csv('val_abstract_ilabel.csv')
test.to_csv('test_abstract_ilabel.csv')

In [20]:
import json

label2id = {}
id2label = {}

with open("id2label.json", "r") as f:
    id2label = json.load(f)
    
with open("label2id.json", "r") as f:
    label2id = json.load(f)

print(id2label)
print(label2id)

['physics', 'hep-ph', 'eess', 'astro-ph', 'hep-th', 'quant-ph', 'stat', 'math', 'cond-mat', 'cs']
{'physics': 0, 'hep-ph': 1, 'eess': 2, 'astro-ph': 3, 'hep-th': 4, 'quant-ph': 5, 'stat': 6, 'math': 7, 'cond-mat': 8, 'cs': 9}


In [21]:
import torch
import torch.nn as nn

max_length = 512
padding = "max_length"
truncation = True
model_name = "bert-base-uncased"
num_epoch = 100
batch_size = 10
num_labels = 10
add_special_tokens=True
return_attention_mask=True
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [22]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
from torch.utils.data import Dataset, DataLoader

class bert_dataset(Dataset):
    def __init__(self, data, tokenizer, max_length, add_special_tokens, padding, return_attention_mask, truncation):
        self.texts = data['abstract']
        self.labels = data['label']
        self.tokenizer = tokenizer
        input_ids = []
        attention_masks = []

        for text in self.texts:
            encoded_dict = tokenizer.encode_plus(
                text,
                add_special_tokens=add_special_tokens,
                max_length=max_length,
                padding=padding,
                return_attention_mask=return_attention_mask,
                return_tensors='pt',
                truncation=truncation
            )
            input_ids.append(encoded_dict['input_ids'])
            attention_masks.append(encoded_dict['attention_mask'])

        self.input_ids = torch.cat(input_ids, axis=0)
        self.attention_masks = torch.cat(attention_masks, axis=0)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx],
                'attention_mask': self.attention_masks[idx],
                'labels': self.labels[idx]}

In [26]:
train_data = pd.read_csv('train_abstract_ilabel.csv')
train_dataset = bert_dataset(train_data, tokenizer, max_length=max_length, add_special_tokens=add_special_tokens, 
                              padding=padding, return_attention_mask=return_attention_mask, truncation=truncation)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_data = pd.read_csv('val_abstract_ilabel.csv')
val_dataset = bert_dataset(val_data, tokenizer, max_length=max_length, add_special_tokens=add_special_tokens, 
                              padding=padding, return_attention_mask=return_attention_mask, truncation=truncation)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

test_data = pd.read_csv('test_abstract_ilabel.csv')
test_dataset = bert_dataset(test_data, tokenizer, max_length=max_length, add_special_tokens=add_special_tokens, 
                              padding=padding, return_attention_mask=return_attention_mask, truncation=truncation)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True)

KeyboardInterrupt: 

In [27]:
from transformers import AdamW, AutoModelForSequenceClassification
import wandb

optimizer = AdamW(model.parameters(), lr=2e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
# best_loss = float('inf')  # initialize with a very large value
wandb.init(entity='zmengaf', project='rcfda', name='arxiv_bert')
wandb.config.update({
    "max_length": max_length,
    "num_epoch": num_epoch,
    "batch_size": batch_size,
    "num_labels": num_labels,
})

max_val_acc = 0
for epoch in range(num_epoch):
# for epoch in range(4, 6):
    model.train()
    train_loss = 0
    train_correct = 0
    for batch_idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        preds = torch.argmax(outputs.logits, dim=1)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        acc = (preds == labels).float().mean()
        train_loss += loss.item()
        train_correct += (outputs.logits.argmax(dim=1) == labels).float().sum().item()
    train_loss /= len(train_dataloader)
    train_acc = train_correct / len(train_dataset)
    print(f"Epoch {epoch}: train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}")
    
    model.eval()
    val_loss = 0
    val_correct = 0
    for batch in val_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1)
            val_loss += loss.item()
            val_correct += (outputs.logits.argmax(dim=1) == labels).float().sum().item()
            print(val_correct)
    val_loss /= len(val_dataloader)
    val_acc = val_correct / len(val_dataset)
    if val_acc > max_val_acc:
        max_val_acc = val_acc
        torch.save(model, f"bert_model_arxiv_acc_{max_val_acc}.pt")
    print(f"Epoch {epoch}: val_loss = {val_loss:.4f}, val_acc = {val_acc:.4f}")
    wandb.log({"epoch": epoch, "train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc})





0,1
epoch,▁▂▃▅▆▇█
train_acc,▁▄▅▇▇██
train_loss,█▅▄▂▂▁▁
val_acc,▁▄▅▇▇██
val_loss,▂▁▂▃▅▆█

0,1
epoch,6.0
train_acc,9.48557
train_loss,0.14963
val_acc,9.48557
val_loss,0.18358


Epoch 0: train_loss = 0.1300, train_acc = 9.5537
8.0
16.0
21.0
29.0
38.0
47.0
53.0
61.0
70.0
78.0
84.0
93.0
103.0
110.0
118.0
126.0
135.0
144.0
152.0
162.0
170.0
178.0
186.0
192.0
201.0
209.0
218.0
227.0
233.0
241.0
248.0
258.0
267.0
273.0
280.0
289.0
297.0
306.0
314.0
322.0
332.0
339.0
348.0
356.0
365.0
373.0
379.0
387.0
396.0
404.0
413.0
421.0
428.0
435.0
443.0
452.0
459.0
467.0
474.0
482.0
489.0
498.0
506.0
515.0
522.0
529.0
538.0
545.0
554.0
562.0
570.0
578.0
587.0
597.0
605.0
614.0
622.0
630.0
637.0
645.0
654.0
659.0
666.0
673.0
679.0
684.0
691.0
701.0
708.0
715.0
724.0
733.0
740.0
747.0
755.0
763.0
773.0
781.0
787.0
795.0
802.0
811.0
819.0
827.0
834.0
839.0
848.0
856.0
863.0
871.0
881.0
889.0
895.0
903.0
910.0
917.0
927.0
936.0
945.0
953.0
960.0
966.0
973.0
978.0
984.0
994.0
1001.0
1009.0
1017.0
1026.0
1035.0
1044.0
1053.0
1062.0
1070.0
1078.0
1087.0
1094.0
1102.0
1110.0
1120.0
1128.0
1136.0
1143.0
1149.0
1155.0
1163.0
1170.0
1175.0
1185.0
1192.0
1199.0
1207.0
1215.0
1222.0
1229.

KeyboardInterrupt: 