## Fine-tuning SciBERT for multi-label text classification of research articles

In this notebook, we fine-tune a BERT model to predict one or more labels for a given piece of text. Specifically, the notebook fine-tunes SciBERT (`scibert_scivocab_uncased`) on the titles and abstracts of research articles.

The classifier adds a linear layer on top of the base model. This layer produces a tensor of shape (batch_size, num_labels) holding the unnormalized scores (logits) for each label, for every example in the batch.
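To make the shapes concrete, here is a minimal NumPy sketch of such a linear head. It is only an illustration under assumptions: the random `pooled` array stands in for the encoder output, and `W` and `b` are hypothetical weights, not the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)

# 768 and 57 match the layer sizes used in the model class of this notebook
batch_size, hidden_size, num_labels = 4, 768, 57

# Stand-in for the pooled encoder output: one vector per example
pooled = rng.normal(size=(batch_size, hidden_size))

# Linear head: one weight column and one bias per label (hypothetical values)
W = rng.normal(scale=0.02, size=(hidden_size, num_labels))
b = np.zeros(num_labels)

logits = pooled @ W + b                  # unnormalized scores, shape (batch_size, num_labels)
probs = 1.0 / (1.0 + np.exp(-logits))    # independent sigmoid per label

print(logits.shape)
```

Because each label gets its own sigmoid, the probabilities in a row do not have to sum to 1, which is what allows several labels to be predicted for the same text.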




## INSTALL LIBRARIES

In [13]:
# !pip install transformers torch nltk pandas scikit-learn

## MAKE NECESSARY IMPORTS

In [1]:
import numpy as np
import pandas as pd
import time
import datetime
import gc
import random
import re


import torch
import torch.nn as nn
from torch.utils.data import Subset
from torch.utils.data import TensorDataset, Dataset, DataLoader, RandomSampler, SequentialSampler,random_split

from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import classification_report, f1_score

import transformers
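from transformers import BertTokenizer, AutoTokenizer, BertModel, BertConfig, AutoModel
from torch.optim import AdamW  # the AdamW class in transformers is deprecated in favor of the PyTorch optimizer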

import warnings
warnings.filterwarnings("ignore")

## PATH VARIABLES

In [3]:
TOKENIZER_SCIBERT_PATH="/kaggle/input/tokenizerscibert/"
MODEL_SCIBERT_PATH="/kaggle/input/modelscibert/"
TRAIN_DATASET_PATH="/kaggle/input/train-csv/train.csv"

## LOAD DATA
Load the training dataset. The tokenizer and SciBERT model are also loaded from local files, so the notebook can run offline.


In [5]:
# Load Dataset
df = pd.read_csv(TRAIN_DATASET_PATH)
df.head()

Unnamed: 0,Id,Title,Abstract,Categories
0,9707,Axiomatic Aspects of Default Inference,This paper studies axioms for nonmonotonic con...,['cs.LO']
1,24198,On extensions of group with infinite conjugacy...,We characterize the group property of being wi...,['math.GR']
2,35766,An Analysis of Complex-Valued CNNs for RF Data...,Recent deep neural network-based device classi...,"['cs.LG', 'cs.IT', 'eess.SP', 'math.IT']"
3,14322,On the reconstruction of the drift of a diffus...,The problem of reconstructing the drift of a d...,"['math.PR', 'math.ST', 'stat.TH']"
4,709,Three classes of propagation rules for GRS and...,"In this paper, we study the Hermitian hulls of...","['cs.IT', 'math.IT']"


In [6]:
df['text'] = df['Title'] +" "+ df['Abstract']
del df['Title']
del df['Abstract']
df['Categories'] = df['Categories'].str.replace(', ', ',')
df['Categories'] = df['Categories'].str.strip('[]')
categories_df = df['Categories'].str.get_dummies(sep=',')
df = pd.concat([df.drop('Categories', axis=1), categories_df], axis=1)
df.head()


Unnamed: 0,Id,text,'cs.AI','cs.AR','cs.CE','cs.CL','cs.CR','cs.CV','cs.DB','cs.DC',...,'q-fin.MF','q-fin.PM','q-fin.PR','q-fin.RM','q-fin.TR','stat.AP','stat.CO','stat.ME','stat.ML','stat.TH'
0,9707,Axiomatic Aspects of Default Inference This pa...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,24198,On extensions of group with infinite conjugacy...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,35766,An Analysis of Complex-Valued CNNs for RF Data...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,14322,On the reconstruction of the drift of a diffus...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
4,709,Three classes of propagation rules for GRS and...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
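The `Categories` column stores each label list as a string, which is why the dummy columns above still carry quote characters. As an alternative sketch (not what the notebook does), the strings could be parsed with the standard library's `ast.literal_eval` before building the multi-hot rows. The two sample strings follow the format shown in the table above; the variable names are illustrative.

```python
import ast

# Strings in the same format as the Categories column
rows = ["['cs.LO']", "['cs.IT', 'math.IT']"]

# Turn each string into a real Python list of label names
parsed = [ast.literal_eval(r) for r in rows]

# Sorted vocabulary of all labels that occur
labels = sorted({c for cats in parsed for c in cats})

# One row per article, one 0/1 entry per label
multi_hot = [[int(label in cats) for label in labels] for cats in parsed]

print(labels)      # ['cs.IT', 'cs.LO', 'math.IT']
print(multi_hot)   # [[0, 1, 0], [1, 0, 1]]
```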


In [7]:


def clean_text(text):
    text = text.lower()

    # Remove URLs and HTML tags first, while ':', '/', '<' and '>' are still in the text
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"<.*?>", "", text)

    # Replace every run of characters other than letters and ? . ! , ¿ with a single space
    text = re.sub(r"[^a-zA-Z?.!,¿]+", " ", text)

    # Remove the remaining punctuation; commas are kept
    punctuations = '@#!?+&*[]-%.:/();$=><|{}^' + "'`" + '_'
    for p in punctuations:
        text = text.replace(p, '')

    # Collapse repeated whitespace
    return " ".join(text.split())
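The order of the substitutions in a cleaning function like the one above matters. If the character filter runs before the URL and HTML patterns, the characters those patterns rely on (`:`, `/`, `<`, `>`) are already gone. The short check below uses a made-up input string to show the difference.

```python
import re

raw = "See https://example.org/paper <b>now</b>"

# Character filter first: ':' and '/' become spaces, so the URL pattern only removes "https"
filtered_first = re.sub(r"[^a-zA-Z?.!,¿]+", " ", raw.lower())
filtered_first = re.sub(r"http\S+", "", filtered_first)

# URL and tag removal first: the whole URL and both tags disappear
url_first = re.sub(r"http\S+", "", raw.lower())
url_first = re.sub(r"<.*?>", "", url_first)
url_first = re.sub(r"[^a-zA-Z?.!,¿]+", " ", url_first)

print(filtered_first.split())  # ['see', 'example.org', 'paper', 'b', 'now', 'b']
print(url_first.split())       # ['see', 'now']
```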

In [8]:
df['text'] = df['text'].apply(clean_text)

In [9]:
df.columns = df.columns.str.strip("'")
new_column_name = 'title_summary'
df = df.rename(columns={'text': new_column_name})
df.head()

Unnamed: 0,Id,title_summary,cs.AI,cs.AR,cs.CE,cs.CL,cs.CR,cs.CV,cs.DB,cs.DC,...,q-fin.MF,q-fin.PM,q-fin.PR,q-fin.RM,q-fin.TR,stat.AP,stat.CO,stat.ME,stat.ML,stat.TH
0,9707,axiomatic aspects of default inference this pa...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,24198,on extensions of group with infinite conjugacy...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,35766,an analysis of complex valued cnns for rf data...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,14322,on the reconstruction of the drift of a diffus...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
4,709,three classes of propagation rules for grs and...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


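As we can see, each row now contains the article `Id`, the cleaned title and abstract in `title_summary`, and one binary column per category. The data comes from a single training file, so validation folds are created later with K-fold cross-validation.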

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [11]:
MAX_LEN = 512
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 16
EPOCHS = 3
LEARNING_RATE = 2e-5
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_SCIBERT_PATH)

In [12]:
target_cols = [col for col in df.columns if col not in ['Id', 'title_summary']]
target_cols

['cs.AI',
 'cs.AR',
 'cs.CE',
 'cs.CL',
 'cs.CR',
 'cs.CV',
 'cs.DB',
 'cs.DC',
 'cs.DM',
 'cs.GT',
 'cs.IR',
 'cs.IT',
 'cs.LG',
 'cs.LO',
 'cs.NI',
 'cs.OS',
 'cs.PL',
 'cs.RO',
 'cs.SD',
 'cs.SE',
 'econ.EM',
 'econ.GN',
 'econ.TH',
 'eess.AS',
 'eess.IV',
 'eess.SP',
 'math.AC',
 'math.AP',
 'math.AT',
 'math.CO',
 'math.CV',
 'math.GR',
 'math.IT',
 'math.LO',
 'math.NT',
 'math.PR',
 'math.QA',
 'math.ST',
 'q-bio.BM',
 'q-bio.CB',
 'q-bio.GN',
 'q-bio.MN',
 'q-bio.NC',
 'q-bio.TO',
 'q-fin.CP',
 'q-fin.EC',
 'q-fin.GN',
 'q-fin.MF',
 'q-fin.PM',
 'q-fin.PR',
 'q-fin.RM',
 'q-fin.TR',
 'stat.AP',
 'stat.CO',
 'stat.ME',
 'stat.ML',
 'stat.TH']

In [13]:
class BERTDataset(Dataset):
    def __init__(self, df, tokenizer, max_len):
        self.df = df
        self.max_len = max_len
        self.text = df.title_summary
        self.tokenizer = tokenizer
        self.targets = df[target_cols].values
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        text = self.text.iloc[index]  # positional lookup, independent of the DataFrame's index labels
        inputs = self.tokenizer.encode_plus(
            text,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'targets': torch.tensor(self.targets[index], dtype=torch.float)
        }
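The tokenizer call above pads or truncates every text to `max_length` and returns an attention mask that marks the real tokens. The sketch below imitates only that behavior with a hypothetical helper, `pad_or_truncate`, and made-up token ids. It is a simplification: the real tokenizer also adds special tokens and keeps them in place when truncating.

```python
def pad_or_truncate(token_ids, max_len, pad_id=0):
    """Hypothetical helper: cut the sequence to max_len, then pad it up to max_len."""
    ids = token_ids[:max_len]
    mask = [1] * len(ids)            # 1 marks a real token
    padding = max_len - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding  # 0 marks padding

short_ids, short_mask = pad_or_truncate([5, 6, 7], max_len=5)
long_ids, long_mask = pad_or_truncate([1, 2, 3, 4, 5, 6], max_len=4)

print(short_ids, short_mask)  # [5, 6, 7, 0, 0] [1, 1, 1, 0, 0]
print(long_ids, long_mask)    # [1, 2, 3, 4] [1, 1, 1, 1]
```

Fixed-length outputs are what let the default `DataLoader` collation stack the examples into one tensor per batch.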

In [14]:
df[target_cols].values

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [15]:
train_dataset = BERTDataset(df, tokenizer, MAX_LEN)

In [16]:
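class BERTClass(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # SciBERT encoder loaded from local files (a BERT model, despite the attribute name)
        self.roberta = AutoModel.from_pretrained(MODEL_SCIBERT_PATH)
#         self.l2 = torch.nn.Dropout(0.3)
        # One output logit per label: hidden size 768, 57 labels
        self.fc = torch.nn.Linear(self.roberta.config.hidden_size, len(target_cols))
    
    def forward(self, ids, mask, token_type_ids):
        # With return_dict=False the encoder returns (last_hidden_state, pooler_output)
        _, features = self.roberta(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False)
#         output_2 = self.l2(output_1)
        output = self.fc(features)
        return output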

In [17]:
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)
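`BCEWithLogitsLoss` applies a sigmoid to every logit and averages the binary cross-entropy over all (example, label) pairs. The pure-Python sketch below reproduces that computation for small inputs, as an illustration rather than a replacement for the PyTorch loss. The function name and sample values are made up.

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all (example, label) pairs, computed from raw logits."""
    total, count = 0.0, 0
    for row_x, row_y in zip(logits, targets):
        for x, y in zip(row_x, row_y):
            # Numerically stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count

# Logits of zero (an uninformed head) give a loss of ln(2) regardless of the targets
untrained = bce_with_logits([[0.0, 0.0, 0.0]], [[1.0, 0.0, 0.0]])
print(untrained)
```

This explains the first value in the training log further down: a loss of about 0.69 is close to ln(2) ≈ 0.693, which is what a freshly initialized head with near-zero logits produces.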

In [18]:
def find_best_threshold(y_true, y_pred_logits):
    # The model outputs raw logits, so the candidate thresholds are in logit space
    # (a logit of 0 corresponds to a probability of 0.5)
    best_threshold = 0.0
    best_f1_score = 0.0
    
    for threshold in np.arange(-2.0, 2.0, 0.05):
        y_pred = (y_pred_logits >= threshold).astype(int)
        f1 = f1_score(y_true, y_pred, average='macro')
        
        if f1 > best_f1_score:
            best_f1_score = f1
            best_threshold = threshold
            
    return best_threshold, best_f1_score
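A toy example may help show why a threshold below zero can give a better macro F1. The sketch below uses a small NumPy re-implementation of macro F1 (`macro_f1`, a hypothetical helper, not the scikit-learn function) and made-up logits in which the second label is systematically under-scored.

```python
import numpy as np

def macro_f1(y_true, y_pred):
    """Per-label F1 averaged over labels; a label with no true or predicted positives scores 0."""
    tp = ((y_true == 1) & (y_pred == 1)).sum(axis=0)
    fp = ((y_true == 0) & (y_pred == 1)).sum(axis=0)
    fn = ((y_true == 1) & (y_pred == 0)).sum(axis=0)
    denom = 2 * tp + fp + fn
    f1 = np.where(denom > 0, 2 * tp / np.maximum(denom, 1), 0.0)
    return f1.mean()

y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
logits = np.array([[0.4, -3.0], [-2.5, -0.6], [1.2, -0.9], [-1.8, -2.2]])

scores = {}
for threshold in (0.0, -1.0):
    y_pred = (logits >= threshold).astype(int)
    scores[threshold] = macro_f1(y_true, y_pred)

print(scores[0.0], scores[-1.0])  # 0.5 1.0
```

At a threshold of 0 the second label is never predicted, so its F1 is 0 and it pulls the macro average down. Lowering the threshold recovers it.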

In [19]:
def train(epoch):
    kf = KFold(n_splits=3, shuffle=True, random_state = 42)
    results = []
    final_threshold = 0
    
    for fold, (train_ids, val_ids) in enumerate(kf.split(train_dataset)):
        
        model = BERTClass()
        model.to(device)
        optimizer = AdamW(params =  model.parameters(), lr=LEARNING_RATE, weight_decay=1e-6)
        train_subs = Subset(train_dataset, train_ids)
        val_subs = Subset(train_dataset, val_ids)
        train_loader = DataLoader(train_subs, batch_size= TRAIN_BATCH_SIZE, shuffle=True)
        validation_loader = DataLoader(val_subs, batch_size= VALID_BATCH_SIZE, shuffle=False)

        model.train()
        for i in range(epoch):
            for j, data in enumerate(train_loader):
                ids = data['ids'].to(device, dtype = torch.long)
                mask = data['mask'].to(device, dtype = torch.long)
                token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
                targets = data['targets'].to(device, dtype = torch.float)
                outputs = model(ids, mask, token_type_ids)
                loss = loss_fn(outputs, targets)
                if j%50 == 0:
                    print(f'Epoch: {i}, batch : {j} Loss:  {loss.item()}')
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                
        model.eval()
        predictions = []
        targets = []
        with torch.no_grad():
            for i, data in enumerate(validation_loader):
                ids = data['ids'].to(device, dtype = torch.long)
                mask = data['mask'].to(device, dtype = torch.long)
                token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
                # Store labels as plain lists so scikit-learn receives a 2-D array
                targets.extend(data['targets'].tolist())
                outputs = model(ids, mask, token_type_ids)
                # No squeeze(): it would drop the batch dimension if the last batch held one example
                predictions.extend(outputs.cpu().tolist())
        pred = np.array(predictions)
        # The threshold is searched on raw logits, so it can be negative
        threshold, fold_f1 = find_best_threshold(np.array(targets).astype(int), pred)
        # Average the best thresholds across folds
        final_threshold += threshold / kf.get_n_splits()
        results.append(fold_f1)
        print(f"f1_score for fold {fold} : {fold_f1}")
        
    model = BERTClass()
    model.to(device)
    optimizer = AdamW(params =  model.parameters(), lr=LEARNING_RATE, weight_decay=1e-6)
    train_loader = DataLoader(train_dataset, batch_size= TRAIN_BATCH_SIZE, shuffle=True)
    model.train()
    for i in range(epoch):
        for j, data in enumerate(train_loader):
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float)

            outputs = model(ids, mask, token_type_ids)

            loss = loss_fn(outputs, targets)
            if j%50 == 0:
                print(f'Epoch: {i}, batch : {j} Loss:  {loss.item()}')
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    return model, results, final_threshold
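The function above first uses 3-fold cross-validation to estimate the F1 score and threshold, then retrains a fresh model on all of the data. The sketch below shows how such folds partition the example indices. `kfold_indices` is a hypothetical NumPy stand-in for `KFold(shuffle=True)`; its shuffling does not reproduce scikit-learn's exact splits.

```python
import numpy as np

def kfold_indices(n_samples, n_splits, seed=42):
    """Hypothetical stand-in for KFold(shuffle=True): yields (train_idx, val_idx) pairs."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), n_splits)
    for k in range(n_splits):
        val_idx = folds[k]
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != k])
        yield train_idx, val_idx

splits = list(kfold_indices(10, 3))
for fold, (train_idx, val_idx) in enumerate(splits):
    print(fold, len(train_idx), len(val_idx))
```

Every example appears in exactly one validation fold, so the three per-fold scores together cover the whole training set.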

In [20]:
epoch = EPOCHS  # reuse the value set with the other hyperparameters

In [22]:
model, results, final_threshold = train(epoch)
torch.save(model.state_dict(), 'model.bin')

Epoch: 0, batch : 0 Loss:  0.6948950290679932
Epoch: 0, batch : 50 Loss:  0.2813558876514435
Epoch: 0, batch : 100 Loss:  0.19365939497947693
Epoch: 0, batch : 150 Loss:  0.14521071314811707
Epoch: 0, batch : 200 Loss:  0.15278439223766327
Epoch: 0, batch : 250 Loss:  0.13711890578269958
Epoch: 0, batch : 300 Loss:  0.11356864124536514
Epoch: 0, batch : 350 Loss:  0.1287909597158432
Epoch: 0, batch : 400 Loss:  0.11305023729801178
Epoch: 0, batch : 450 Loss:  0.11815514415502548
Epoch: 0, batch : 500 Loss:  0.0955762192606926
Epoch: 0, batch : 550 Loss:  0.10760829597711563
Epoch: 0, batch : 600 Loss:  0.11473220586776733
Epoch: 0, batch : 650 Loss:  0.09339188784360886
Epoch: 0, batch : 1350 Loss:  0.06779580563306808
Epoch: 0, batch : 1400 Loss:  0.09142578393220901
Epoch: 0, batch : 1450 Loss:  0.07584300637245178
Epoch: 0, batch : 1500 Loss:  0.07085160911083221
Epoch: 0, batch : 1550 Loss:  0.06282248347997665
Epoch: 0, batch : 1600 Loss:  0.04811670258641243
Epoch: 0, batch : 165

In [23]:
results

[0.6687139526017833, 0.6534781469319657, 0.6645800427910645]

In [24]:
final_threshold

-1.0499999999999992
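The threshold is negative because it is applied to raw logits rather than probabilities. As a small sketch, passing it through a sigmoid gives the equivalent probability cutoff. The logits in the second half are hypothetical values for a single example.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

logit_threshold = -1.05  # averaged threshold reported above, rounded
prob_threshold = round(sigmoid(logit_threshold), 3)
print(prob_threshold)  # 0.259

# Applying the threshold to one row of hypothetical logits
logits = [2.3, -1.5, -0.8]
predicted = [int(z >= logit_threshold) for z in logits]
print(predicted)  # [1, 0, 1]
```

In other words, a label is predicted whenever its estimated probability is at least about 0.26, which favors recall compared with the usual 0.5 cutoff.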