# SciBERT Training

## Import packages and read data

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import json
import torch
import re
import seaborn as sns
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score, confusion_matrix, classification_report
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from collections import defaultdict, Counter
from sklearn.preprocessing import LabelEncoder
from torch.optim import AdamW


# Apply seaborn's default theme once, globally, so every figure drawn later in
# the notebook shares the same look.
sns.set_theme()


# Torch
# Seed the NumPy and PyTorch global RNGs up front so that the random batch
# order, dropout masks and the freshly initialised classifier head are
# repeatable after "Restart Kernel -> Run All".
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [10]:
# Read the three JSON Lines splits (one JSON record per line) in the order
# train, test, dev and bind each to its own DataFrame.
train_df, test_df, val_df = (
    pd.read_json(path_or_buf=split_path, lines=True)
    for split_path in (
        r'./drive/MyDrive/scicite/train.jsonl',
        r'./drive/MyDrive/scicite/test.jsonl',
        r'./drive/MyDrive/scicite/dev.jsonl',
    )
)
train_df

Unnamed: 0,source,citeEnd,sectionName,citeStart,string,label,label_confidence,citingPaperId,citedPaperId,isKeyCitation,id,unique_id,excerpt_index,label2,label2_confidence
0,explicit,175.0,Introduction,168.0,"However, how frataxin interacts with the Fe-S ...",background,1.0000,1872080baa7d30ec8fb87be9a65358cd3a7fb649,894be9b4ea46a5c422e81ef3c241072d4c73fdc0,True,1872080baa7d30ec8fb87be9a65358cd3a7fb649>894be...,1872080baa7d30ec8fb87be9a65358cd3a7fb649>894be...,11,,
1,explicit,36.0,Novel Quantitative Trait Loci for Seminal Root...,16.0,"In the study by Hickey et al. (2012), spikes w...",background,1.0000,ce1d09a4a3a8d7fd3405b9328f65f00c952cf64b,b6642e19efb8db5623b3cc4eef1c5822a6151107,True,ce1d09a4a3a8d7fd3405b9328f65f00c952cf64b>b6642...,ce1d09a4a3a8d7fd3405b9328f65f00c952cf64b>b6642...,2,,
2,explicit,228.0,Introduction,225.0,"The drug also reduces catecholamine secretion,...",background,1.0000,9cdf605beb1aa1078f235c4332b3024daa8b31dc,4e6a17fb8d7a3cada601d942e22eb5da6d01adbd,False,9cdf605beb1aa1078f235c4332b3024daa8b31dc>4e6a1...,9cdf605beb1aa1078f235c4332b3024daa8b31dc>4e6a1...,0,,
3,explicit,110.0,Discussion,46.0,By clustering with lowly aggressive close kin ...,background,1.0000,d9f3207db0c79a3b154f3875c9760cc6b056904b,2cc6ff899bf17666ad35893524a4d61624555ed7,False,d9f3207db0c79a3b154f3875c9760cc6b056904b>2cc6f...,d9f3207db0c79a3b154f3875c9760cc6b056904b>2cc6f...,3,,
4,explicit,239.0,Discussion,234.0,Ophthalmic symptoms are rare manifestations of...,background,1.0000,88b86556857f4374842d2af2e359576806239175,a5bb0ff1a026944d2a47a155462959af2b8505a8,False,88b86556857f4374842d2af2e359576806239175>a5bb0...,88b86556857f4374842d2af2e359576806239175>a5bb0...,2,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8238,explicit,50.0,,28.0,"Importantly, the results of Pascalis et al. (2...",background,0.7350,6f68ccd37718366c40ae6aeedf0b935bf560b215,60ed4bdabf92b2fbd6162dbd8979888cccca55d7,True,6f68ccd37718366c40ae6aeedf0b935bf560b215>60ed4...,6f68ccd37718366c40ae6aeedf0b935bf560b215>60ed4...,15,,
8239,explicit,182.0,DISCUSSION,179.0,"As suggested by Nguena et al, there is a need ...",background,0.7508,f2a1c1704f9587c94ed95bc98179dc499e933f5e,574e659da7f6c62c07bfaaacd1f31d65bd75524c,True,f2a1c1704f9587c94ed95bc98179dc499e933f5e>574e6...,f2a1c1704f9587c94ed95bc98179dc499e933f5e>574e6...,1,,
8240,explicit,120.0,DISCUSSION,108.0,Skeletal muscle is also a primary site of dise...,background,1.0000,18c97ea2ff60c110cc2a523e0fdf729608cbb083,fc13b9c3dfcc121013edaa12fa8ce7842aaed21a,False,18c97ea2ff60c110cc2a523e0fdf729608cbb083>fc13b...,18c97ea2ff60c110cc2a523e0fdf729608cbb083>fc13b...,8,,
8241,explicit,221.0,,185.0,ACTIVATION OF TRANSCRIPTION FACTORS Roles for ...,method,,4ec9b89857c0b27e8a4bd3745b7358f387773527,81affdba19e38e2b17cf7b9e93792cc2028cf21d,True,4ec9b89857c0b27e8a4bd3745b7358f387773527>81aff...,4ec9b89857c0b27e8a4bd3745b7358f387773527>81aff...,0,,


## Setup

In [11]:
max_length = 100
feature_name = "string"

# Tokenizer
MODEL_NAME = 'allenai/scibert_scivocab_uncased'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=True)


def encode_texts(texts):
    """Tokenize an array of raw strings into fixed-length PyTorch tensors.

    Every sequence is truncated or padded to exactly ``max_length`` tokens so
    the three splits can be stacked into tensors of identical width.

    Args:
        texts: Array-like of strings exposing ``.tolist()`` (here the
            ``.values`` of a DataFrame column).

    Returns:
        The tokenizer's batch encoding; the cells below read its
        ``'input_ids'`` and ``'attention_mask'`` tensors.
    """
    # `padding='max_length'` is the supported spelling of the deprecated
    # `pad_to_max_length=True` flag and yields the same padded output.
    return tokenizer(texts.tolist(),
                     padding='max_length',
                     max_length=max_length,
                     truncation=True,
                     return_tensors='pt')


X_train = train_df[feature_name].values
X_train_encoded = encode_texts(X_train)
X_test = test_df[feature_name].values
X_test_encoded = encode_texts(X_test)
X_val = val_df[feature_name].values
X_val_encoded = encode_texts(X_val)

# Label Encoder
# The string-to-integer mapping is learned from the training labels only and
# then reused unchanged for the test and validation splits.
label_encoder = LabelEncoder()
y_train = torch.tensor(label_encoder.fit_transform(train_df['label'].values))
y_test, y_val = (
    torch.tensor(label_encoder.transform(split_df['label'].values))
    for split_df in (test_df, val_df)
)



In [12]:
# Create dataset
batch_size = 32


def _to_tensor_dataset(encoded, labels):
    """Bundle token ids, attention masks and labels into one TensorDataset."""
    return torch.utils.data.TensorDataset(encoded['input_ids'],
                                          encoded['attention_mask'],
                                          labels)


# Training batches are drawn in random order.
train_dataset = _to_tensor_dataset(X_train_encoded, y_train)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=torch.utils.data.RandomSampler(train_dataset),
)

# Test and validation batches keep the original row order so predictions can
# be lined up with the source DataFrames.
test_dataset = _to_tensor_dataset(X_test_encoded, y_test)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    sampler=torch.utils.data.SequentialSampler(test_dataset),
)

val_dataset = _to_tensor_dataset(X_val_encoded, y_val)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    sampler=torch.utils.data.SequentialSampler(val_dataset),
)

In [13]:
class SciBERTClassifier(torch.nn.Module):
    """Pretrained BERT encoder followed by a small classification head.

    The head is: dropout -> linear -> layer norm -> tanh -> dropout -> linear,
    applied to element ``[1]`` of the encoder output.

    Args:
        dropout_rate: Dropout probability used before both linear layers.
        model_name: Checkpoint passed to ``BertModel.from_pretrained``.
        hidden_dim: Width of the intermediate linear layer.
        num_classes: Number of output logits per example.

    The defaults reproduce the original fixed architecture (SciBERT uncased,
    64 hidden units, 3 classes).
    """

    def __init__(self,
                 dropout_rate=0.3,
                 model_name='allenai/scibert_scivocab_uncased',
                 hidden_dim=64,
                 num_classes=3):
        super().__init__()

        self.SciBERT = BertModel.from_pretrained(model_name)
        # Read the encoder width from its config instead of assuming 768, so
        # checkpoints of a different size work without further edits.
        encoder_dim = self.SciBERT.config.hidden_size

        self.d1 = torch.nn.Dropout(dropout_rate)
        self.l1 = torch.nn.Linear(encoder_dim, hidden_dim)
        # Despite the "bn" prefix this is a LayerNorm; the attribute name is
        # kept so the keys of existing state dicts / checkpoints do not change.
        self.bn1 = torch.nn.LayerNorm(hidden_dim)
        self.d2 = torch.nn.Dropout(dropout_rate)
        self.l2 = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits, one row per input sequence.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Padding mask forwarded to the encoder.
        """
        encoded = self.SciBERT(input_ids=input_ids, attention_mask=attention_mask)
        # Index 1 of BertModel's output is the pooled sequence representation
        # (per the transformers BertModel documentation).
        x = encoded[1]
        x = self.d1(x)
        x = self.l1(x)
        x = self.bn1(x)
        # Functional tanh: avoids building a new Tanh module on every call.
        x = torch.tanh(x)
        x = self.d2(x)
        x = self.l2(x)
        return x


# Build the classifier with 40% dropout and move all of its weights to the
# selected device in one step (`.to` returns the module itself).
model = SciBERTClassifier(dropout_rate=0.4).to(device)

## Train

In [25]:
def evaluate(model, val_dataloader, val_size=None):
    """Score ``model`` on every batch of ``val_dataloader`` without training it.

    The model is switched to eval mode for the duration of the call and put
    back into whichever mode (train or eval) it was in beforehand, even if an
    error occurs part-way through.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns one row of class logits per example.
        val_dataloader: Iterable of ``(input_ids, attention_mask, labels)``
            batches.
        val_size: Divisor applied to the summed per-batch losses. ``train``
            passes ``len(val_dataloader)`` (the number of batches), which is
            also what is used when this argument is omitted.

    Returns:
        Tuple ``(val_loss, val_f1, y_true, y_pred)``: the summed per-batch
        cross-entropy divided by ``val_size``, the macro-averaged F1 score,
        and the concatenated true / predicted label arrays.

    Raises:
        ValueError: If ``val_dataloader`` yields no batches.
    """
    if val_size is None:
        val_size = len(val_dataloader)

    # Follow the model's own device rather than a notebook-level global, so
    # the function does not depend on hidden kernel state.
    device = next(model.parameters()).device
    criterion = torch.nn.CrossEntropyLoss()

    val_loss = 0
    y_pred = []
    y_true = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch[0].to(device)
                attention_masks = batch[1].to(device)
                # CrossEntropyLoss expects integer (long) class indices.
                labels = batch[2].to(device=device, dtype=torch.long)

                logits = model(input_ids=input_ids,
                               attention_mask=attention_masks)

                val_loss += criterion(logits, labels).item()

                y_pred.append(logits.argmax(dim=-1).cpu().numpy())
                y_true.append(labels.cpu().numpy())
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)

    if not y_pred:
        raise ValueError("val_dataloader yielded no batches; nothing to evaluate.")

    y_pred = np.concatenate(y_pred)
    y_true = np.concatenate(y_true)
    val_loss = val_loss / val_size
    val_f1 = f1_score(y_true, y_pred, average="macro")
    return val_loss, val_f1, y_true, y_pred


def train(model,
          optimizer,
          train_dataloader,
          val_dataloader,
          scheduler = None,
          num_epochs = 5,
          checkpoint_path = 'scibert_best.pt',
         ):
    """Fine-tune ``model`` and checkpoint it whenever validation macro F1 improves.

    Each epoch makes one pass over ``train_dataloader`` with cross-entropy
    loss, then scores the model with ``evaluate`` on ``val_dataloader`` and
    prints a summary.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns class logits.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_dataloader: Sized iterable of
            ``(input_ids, attention_mask, labels)`` batches.
        val_dataloader: Sized iterable of batches in the same layout.
        scheduler: Optional learning-rate scheduler, stepped once per batch
            right after the optimizer.
        num_epochs: Number of passes over ``train_dataloader``.
        checkpoint_path: File the best model so far is saved to.

    Returns:
        Dict with the per-epoch lists ``'train_loss'``, ``'val_loss'``,
        ``'train_f1'`` and ``'val_f1'``. (These histories used to be collected
        and then discarded.)

    Raises:
        ValueError: If ``train_dataloader`` has no batches.
    """
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_f1': [],
        'val_f1': [],
    }

    best_val_f1 = 0

    # Both sizes are batch counts; losses are averaged per batch.
    train_size = len(train_dataloader)
    val_size = len(val_dataloader)
    if train_size == 0:
        raise ValueError("train_dataloader has no batches; nothing to train on.")

    # Follow the model's own device rather than a notebook-level global.
    device = next(model.parameters()).device
    criterion = torch.nn.CrossEntropyLoss()

    # Train loop
    for epoch in range(num_epochs):
        # Re-assert train mode every epoch so dropout is active regardless of
        # what mode the validation pass left the model in.
        model.train()

        train_loss = 0
        y_pred = []
        y_true = []
        for batch in tqdm(train_dataloader):
            input_ids = batch[0].to(device)
            attention_masks = batch[1].to(device)
            # CrossEntropyLoss expects integer (long) class indices.
            labels = batch[2].to(device=device, dtype=torch.long)

            # Clear gradients first so nothing stale (e.g. left on the
            # optimizer before this call) leaks into the first update.
            optimizer.zero_grad()

            logits = model(input_ids=input_ids,
                           attention_mask=attention_masks)

            loss = criterion(logits, labels)
            loss.backward()

            # Optimizer and scheduler step
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            train_loss += loss.item()

            y_pred.append(logits.argmax(dim=-1).detach().cpu().numpy())
            y_true.append(labels.detach().cpu().numpy())

        y_pred = np.concatenate(y_pred)
        y_true = np.concatenate(y_true)
        train_loss = train_loss / train_size
        train_f1 = f1_score(y_true, y_pred, average="macro")
        history['train_loss'].append(train_loss)
        history['train_f1'].append(train_f1)

        # Validation
        val_loss, val_f1, _, _ = evaluate(model, val_dataloader, val_size)
        history['val_loss'].append(val_loss)
        history['val_f1'].append(val_f1)

        # Print summary
        print(f"Epoch {epoch}:")
        print(f"Train loss: {train_loss:.2f}, Validation loss: {val_loss:.2f}")
        print(f"Train Macro F1: {train_f1:.2f}, Validation Macro F1: {val_f1:.2f}")

        # checkpoint
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            # NOTE(review): this pickles the whole module, so loading it needs
            # the same class definition importable. Saving model.state_dict()
            # is the more portable format, but would change how the file has
            # to be loaded, so the format is left as is.
            torch.save(model, checkpoint_path)
            print(f"Model saved at epoch {epoch}.")

    print('Training done!')
    return history

In [None]:
NUM_EPOCHS = 10  # 15
print("======================= Start training =================================")

# AdamW over every parameter of the model (encoder and classification head).
optimizer = torch.optim.AdamW(
    model.parameters(),
    eps=1e-6,
    betas=(0.9, 0.98),
    lr=2e-5,
)

# Currently disabled: linear warmup / decay schedule. Uncomment this block
# together with the `scheduler=` argument below to switch it back on.
# scheduler = get_linear_schedule_with_warmup(optimizer,
#                                             num_warmup_steps=10,
#                                             num_training_steps=len(train_dataloader) * NUM_EPOCHS)

train(model,
      optimizer,
      train_dataloader,
      val_dataloader,
      # scheduler=scheduler,
      num_epochs=NUM_EPOCHS)