# Multi-Label Section Classifier
https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb

In [None]:
# dependencies
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
import transformers
from sklearn import metrics
import matplotlib.pyplot as plt
from math import ceil
import re
import wandb

## Config

In [None]:
# Master switch for Weights & Biases logging: every wandb call in this notebook is guarded by it.
WANDB = False
# Forwarded as the `resume` argument of wandb.init.
RESUME = False

In [None]:
# set up the device for GPU usage
# NOTE(review): the GPU index is hard-coded to 2; presumably the training host has at least three GPUs — confirm.
torch_device = 'cuda:2' if torch.cuda.is_available() else 'cpu'

In [None]:
# pre-trained model
# model_name is the checkpoint identifier handed to the tokenizer and BertModel `from_pretrained` calls.
model_name = "bert-base-uncased" #["allenai/scibert_scivocab_uncased", "bert-base-uncased"]
model_id = "bert_1s_32_1" #for filenames
wandb_id = "bert_1s_32_1" # run name passed to wandb.init
wandb_project = "section-clf-multisen" # project name passed to wandb.init

In [None]:
# data import parameters
# input file is read line by line, each line being parsed as one JSON object
input_path = '/media/nvme3n1/proj_scisen/datasets/SciSections_sentences.jsonl'

context_width = 1 #number of sentences included in the context (surrounding the target sentence, incl. target sentence) (at least 1)
include_appendices = False #include appendices
included_conferences = ['as','a','b','c'] # paragraphs whose 'rank' is not listed here are skipped

test_splits = [['as','a','b','c']] # [ [['as','a','b','c']], [['as'],['a'],['b','c']] ] test data split by conference rank

In [None]:
# Both paths are used via plain string concatenation, so they must end with a path separator.
output_path_model = '/media/nvme3n1/proj_scisen/models/MLSC/' #directory to which models are saved
output_path_results = "/media/nvme3n1/proj_scisen/results/MLSC/" #directory to which further outputs are saved

In [None]:
validate = False #evaluate based on validation set (True for hyperparameter tuning) or test set
postprocess_labels = False # if True, predictions with 0 or more than 2 labels are replaced by the top-1 / top-2 scoring labels
smooth = False # if True, training labels are smoothed (1 -> 0.9, everything else +0.1)

In [None]:
# hyperparameters
# manual search for both models

MAX_LEN = 512 #tokens
TRAIN_BATCH_SIZE = 32 #[16,32]
VALID_BATCH_SIZE = 32 #[16,32]
EPOCHS = 15 # with early stopping if no improvement for past 2 epochs
LEARNING_RATE = 1e-05 #[5e-05,3e-05,1e-05]

# per-class decision thresholds compared against the sigmoid outputs at evaluation time
LAMBDAS = [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2] #[all 0.1, 0.2, 0.3, 0.5]

In [None]:
# Start the wandb run for training and record the run configuration.
if WANDB:
    wandb.init(
        project=wandb_project,
        resume=RESUME,
        name=wandb_id,
        config={
            "epochs": EPOCHS,
            "context_width": context_width,
            "validation": validate,
            "batch_size": TRAIN_BATCH_SIZE,
            "learning_rate": LEARNING_RATE,
            "lambdas": str(LAMBDAS),
            "trainingdata": "full",
            "conferences": "asabc",
        },
    )

## Import and Preprocess the Data

In [None]:
# load raw data: one JSON object per line of the input file.
# The encoding is fixed to UTF-8 so the result does not depend on the platform default,
# and whitespace-only lines (e.g. a trailing newline at the end of the file) are skipped
# instead of crashing json.loads.
with open(input_path, encoding="utf-8") as f:
    paragraphList = [json.loads(line) for line in f if line.strip()]

In [None]:
#convert text labels for multi-label encoding and construct sentences & context
# Section titles in a fixed order; label columns and label vectors follow this order.
# 'appendix' is only a class when include_appendices is set.
titles = [
    'introduction',
    'related work',
    'method',
    'experiment',
    'result',
    'discussion',
    'conclusion',
] + (['appendix'] if include_appendices else [])
n_sections = len(titles)

def multilabel(classes):
    """Return a dict mapping every entry of `titles` to 1 if it occurs in `classes`, else 0."""
    return {title: int(title in classes) for title in titles}

def context(sentences, sentence_idx, k_context):
    """Join the target sentence and its neighbours into one space-separated string.

    The window holds k_context consecutive sentences of the same paragraph:
    k_context // 2 predecessors, the target sentence, and the remaining
    successors (k=1: target only, k=2: predecessor + target,
    k=3: predecessor + target + successor, ...).

    Returns "" when the window would reach outside `sentences` or when it
    contains a "[removed]" sentence.
    """
    n_before = k_context // 2
    n_after = k_context - n_before - 1
    first = sentence_idx - n_before
    last = sentence_idx + n_after + 1  # exclusive slice end

    # window does not fit into the paragraph
    if first < 0 or last > len(sentences):
        return ""

    window = sentences[first:last]
    if "[removed]" in window:
        return ""
    return " ".join(map(str, window))

In [None]:
# Build one row per non-removed sentence: context text, one 0/1 column per title, conference rank.
inputList = []
for par in tqdm(paragraphList):
    appendix_only = par['section_category'] == ['appendix']
    if (appendix_only and not include_appendices) or par['rank'] not in included_conferences:
        continue
    label = multilabel(par['section_category'])
    par_sentences = par['sentences']
    for idx, sentence in enumerate(par_sentences):
        if sentence == "[removed]":
            continue
        # key order matters: text first, then the label columns, then the rank
        row = {'text_target': context(par_sentences, idx, context_width)}
        row.update(label)
        row['rank'] = par['rank']
        inputList.append(row)

In [None]:
# One row per sentence; columns are text_target, one column per title, rank.
df_input = pd.DataFrame(inputList)
# Replace empty strings (sentences for which no valid context window exists) by NaN so they can be filtered with isnull().
df_input=df_input.mask(df_input == '')
df_input

In [None]:
#inspect: frequencies of sections
#sentences
# collapse the per-title columns (positions 1..n_sections) into one list-valued 'label' column
df_input['label'] = df_input[df_input.columns[1:n_sections+1]].values.tolist()
# count label combinations over rows that have a non-null text
print(df_input[~df_input['text_target'].isnull()]['label'].value_counts())

In [None]:
# list-valued multi-hot label built from the per-title columns (positions 1..n_sections)
df_input['label'] = df_input[df_input.columns[1:n_sections+1]].values.tolist()
# keep only the columns needed downstream and drop rows without text
new_df = df_input[['text_target','label','rank']].copy()
new_df = new_df[~new_df['text_target'].isnull()]
new_df

In [None]:
#inspect: frequencies of ranks
# number of usable sentences per conference rank
print(new_df['rank'].value_counts())

In [None]:
#smaller datasets for (pre-)experiments
# N is only needed by the (currently disabled) subsampling line below
N = len(new_df)

#random subset: 10%
#new_df = new_df.sample(n=ceil(N*0.1), random_state = 42)
print(new_df['rank'].value_counts())
print(new_df['label'].value_counts())
new_df

## Prepare the Dataset and Dataloader

In [None]:
# Tokenizer belonging to the pre-trained checkpoint selected by model_name.
tokenizer = transformers.BertTokenizer.from_pretrained(model_name)

In [None]:
class CustomDataset(Dataset):
    """Dataset that tokenizes one text per item for BERT training/evaluation.

    Args:
        dataframe: frame with a `text_target` column (text) and a `label`
            column (list of per-class target values).
        tokenizer: object providing `encode_plus` (e.g. a BERT tokenizer).
        max_len: length every sequence is truncated/padded to.

    Each item is a dict with the tensors `ids`, `mask`, `token_type_ids`
    (dtype long) and `targets` (dtype float).
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text_target = dataframe.text_target
        self.targets = dataframe.label
        self.max_len = max_len

    def __len__(self):
        return len(self.text_target)

    def __getitem__(self, index):
        # Positional access: a DataLoader asks for 0..len-1, which only equals the
        # index *labels* if the frame was reset; .iloc works for any index.
        raw_text = str(self.text_target.iloc[index])
        # collapse all runs of whitespace into single spaces
        text_target = " ".join(raw_text.split())

        inputs = self.tokenizer.encode_plus(
            text_target,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_token_type_ids=True
        )

        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(inputs["token_type_ids"], dtype=torch.long),
            'targets': torch.tensor(self.targets.iloc[index], dtype=torch.float)
        }

In [None]:
# Creating the dataset

# split: 70/20/10
train_size = 0.7
valid_size = 0.2

# split into train+validation and test (fixed seed for reproducibility):
trainvalid_df = new_df.sample(frac=train_size + valid_size, random_state=200)
test_dataset = new_df.drop(trainvalid_df.index).reset_index(drop=True)
trainvalid_df = trainvalid_df.reset_index(drop=True)

# split into train, validation:
train_dataset = trainvalid_df.sample(frac=train_size / (train_size + valid_size), random_state=200)
valid_dataset = trainvalid_df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)
if smooth:
    # label smoothing on the training targets only: 1 -> 0.9, everything else +0.1
    train_dataset["label"] = [
        [x - 0.1 if x == 1 else x + 0.1 for x in labels]
        for labels in train_dataset["label"]
    ]

# split test by conferences (one test frame per entry of test_splits, keyed by str(split)):
test_datasets = dict()
for split in test_splits:
    test_datasets[str(split)] = test_dataset[test_dataset['rank'].isin(split)].reset_index(drop=True)

print("FULL Dataset: {}".format(new_df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("VALIDATION Dataset: {}".format(valid_dataset.shape))
print("FULL TEST Dataset: {}".format(test_dataset.shape))

for split in test_splits:
    print("TEST Dataset ranks {}: {}".format(str(split), test_datasets[str(split)].shape))


training_set = CustomDataset(train_dataset, tokenizer, MAX_LEN) #70% of original data, used to fine tune the model
validation_set = CustomDataset(valid_dataset, tokenizer, MAX_LEN) #20% of original data, used to evaluate the performance of the model during training

testing_sets = dict() #10% of original data, used to evaluate the performance of the trained model
for split in test_splits:
    testing_sets[str(split)] = CustomDataset(test_datasets[str(split)], tokenizer, MAX_LEN)


# Creating the dataloader
# The loaders feed the network batch by batch, so the whole dataset never has to be
# tokenized and held in memory at once; batch_size and max_len control the memory use.

train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

# Evaluation loaders must NOT shuffle: the collected predictions are later matched
# back to dataset rows by position, and the metrics do not depend on the order.
test_params = {'batch_size': VALID_BATCH_SIZE,
               'shuffle': False,
               'num_workers': 0
               }

training_loader = DataLoader(training_set, **train_params)
validation_loader = DataLoader(validation_set, **test_params)

testing_loaders = dict()
for split in test_splits:
    testing_loaders[str(split)] = DataLoader(testing_sets[str(split)], **test_params)

## Create the Network for Fine Tuning

In [None]:
class BERTClass(torch.nn.Module):
    """BERT encoder + dropout (to regularise) + linear layer (to classify).

    Args:
        n_classes: number of output logits. Defaults to the module-level
            `n_sections`, so the head matches the label vector length even
            when appendices are included as an extra class.
    """

    def __init__(self, n_classes=None):
        super().__init__()
        if n_classes is None:
            n_classes = n_sections
        # return_dict=False makes the encoder return a plain tuple
        self.l1 = transformers.BertModel.from_pretrained(model_name, return_dict=False)
        self.l2 = torch.nn.Dropout(0.3)
        # input width is read from the loaded checkpoint instead of being hard-coded
        self.l3 = torch.nn.Linear(self.l1.config.hidden_size, n_classes)

    def forward(self, ids, mask, token_type_ids):
        """Return the raw (pre-sigmoid) logits, one per class, for each input sequence."""
        _, pooled_output = self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids)
        dropped = self.l2(pooled_output)
        return self.l3(dropped)


# Instantiate the classifier and move it to the configured device.
model = BERTClass()
model.to(torch_device)
if WANDB:
    # let wandb track the model, logging every 50 steps (log_freq)
    wandb.watch(model, log_freq = 50)

In [None]:
def loss_fn(outputs, targets):
    """Binary cross-entropy on raw logits, suited to multi-label targets.

    `outputs` are the model's logits (no sigmoid applied), `targets` the
    per-class target values; returns the loss as a scalar tensor.
    """
    criterion = torch.nn.BCEWithLogitsLoss()
    return criterion(outputs, targets)

# Adam over all model parameters (encoder and classification head) with one shared learning rate.
optimizer = torch.optim.Adam(params = model.parameters(), lr = LEARNING_RATE)

## Fine-Tune the Model

In [None]:
# Plot Val loss
def loss_plot(epochs, loss):
    """Plot `loss` over `epochs` and save the figure as <output_path_results><model_id>_val_loss.png.

    Draws on the current pyplot figure; nothing is returned.
    """
    plt.plot(epochs, loss, color='red', label='loss')
    plt.xlabel("epochs")
    plt.title("validation loss")
    plt.savefig(output_path_results+model_id+"_val_loss.png")

# validation loss per epoch, appended to by the training loop
loss_vals = []

In [None]:
# Train Model
# Fine-tune for at most EPOCHS epochs; after every epoch the validation loss is computed,
# the best model so far is checkpointed, and training stops early once the validation
# loss has not improved for n_epochs_stop consecutive epochs.

n_epochs_stop = 2
epochs_no_improve = 0
early_stop = False
min_val_loss = np.inf  # np.Inf was removed in NumPy 2.0
best_model_name = ""
final_epoch = 0
best_epoch = 0
start_epoch = 0

for epoch in range(start_epoch, EPOCHS):
    final_epoch += 1
    train_loss = 0
    valid_loss = 0

    model.train()
    for batch_idx, data in enumerate(tqdm(training_loader), 0):
        optimizer.zero_grad()

        # forward
        ids = data['ids'].to(torch_device, dtype=torch.long)
        mask = data['mask'].to(torch_device, dtype=torch.long)
        token_type_ids = data['token_type_ids'].to(torch_device, dtype=torch.long)
        targets = data['targets'].to(torch_device, dtype=torch.float)
        outputs = model(ids, mask, token_type_ids)

        # backward
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        # incremental mean: mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
        train_loss = train_loss + ((1 / (batch_idx + 1)) * (loss.item() - train_loss))
        if WANDB:
            # log a plain float rather than the tensor
            wandb.log({'epoch': epoch,
                       'loss': loss.item()})

    # validate model
    model.eval()
    with torch.no_grad():
        for batch_idx, data in enumerate(validation_loader, 0):
            ids = data['ids'].to(torch_device, dtype=torch.long)
            mask = data['mask'].to(torch_device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(torch_device, dtype=torch.long)
            targets = data['targets'].to(torch_device, dtype=torch.float)
            outputs = model(ids, mask, token_type_ids)

            loss = loss_fn(outputs, targets)  # validation loss
            valid_loss = valid_loss + ((1 / (batch_idx + 1)) * (loss.item() - valid_loss))

    # train_loss / valid_loss already are per-batch averages (incremental mean above),
    # so they must not be divided by the number of batches a second time.

    # print training/validation statistics
    print('Epoch: {} \tAverage Training Loss: {:.6f} \tAverage Validation Loss: {:.6f}'.format(
        epoch,
        train_loss,
        valid_loss
    ))
    loss_vals.append(valid_loss)
    if WANDB:
        wandb.log({'val_loss': valid_loss})

    # early stopping
    # reference: https://www.kaggle.com/code/akhileshrai/tutorial-early-stopping-vanilla-rnn-pytorch/notebook
    if valid_loss < min_val_loss:
        # new best validation loss: checkpoint the model and reset the patience counter
        best_model_name = output_path_model+model_id+'_best_epoch'+str(epoch)
        best_epoch = epoch
        torch.save(model.state_dict(), best_model_name)
        epochs_no_improve = 0
        min_val_loss = valid_loss
    else:
        epochs_no_improve += 1
    if epochs_no_improve == n_epochs_stop:
        print('Early stopping!' )
        early_stop = True
        break

In [None]:
if WANDB:
    # best_epoch is the 0-based loop index, so +1 records the number of epochs trained for the best checkpoint
    wandb.config.update({"best_epochs": best_epoch+1})

In [None]:
# Plot Loss
# x-axis: integer epoch numbers 1..final_epoch, y-axis: the validation loss recorded for each epoch
loss_plot(np.linspace(1, final_epoch, final_epoch).astype(int), loss_vals)

In [None]:
# Restore the checkpoint with the lowest validation loss and store it again under the plain model_id path.
model.load_state_dict(torch.load(best_model_name))
torch.save(model.state_dict(), output_path_model+model_id)
print("saved as "+output_path_model+model_id)

## Evaluate the Model

In [None]:
# Test Model
# Evaluate the weights stored under output_path_model+model_id, either on the validation
# set (validate=True) or on every test split, and write the metrics to a results file.

model.load_state_dict(torch.load(output_path_model+model_id))

if validate:
    eval_loaders = {"validation set": validation_loader}
else:
    eval_loaders = testing_loaders
    # thresholded predictions / targets per test split, keyed like testing_loaders
    outputs_by_rank = {}
    targets_by_rank = {}

for evalset in eval_loaders:
    print(f"--------------- {evalset} ---------------")
    if WANDB and len(eval_loaders) > 1:
        # one wandb run per split; the split name is reduced to its alphanumeric characters
        evalset_tag = re.sub(r'[\W_]+', '', evalset)
        wandb.init(project=wandb_project, resume=RESUME, name=wandb_id+"_"+evalset_tag, config={"epochs": EPOCHS, "context_width": context_width, "validation": validate, "batch_size": TRAIN_BATCH_SIZE, "learning_rate": LEARNING_RATE, "lambdas": str(LAMBDAS), "trainingdata":"full", "conferences": evalset_tag})
    model.eval()
    fin_targets = []
    fin_outputs = []
    with torch.no_grad():
        for data in tqdm(eval_loaders[evalset]):
            ids = data['ids'].to(torch_device, dtype=torch.long)
            mask = data['mask'].to(torch_device, dtype=torch.long)
            token_type_ids = data['token_type_ids'].to(torch_device, dtype=torch.long)
            targets = data['targets'].to(torch_device, dtype=torch.float)
            outputs = model(ids, mask, token_type_ids)
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            # sigmoid turns the logits into independent per-class probabilities
            fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())

    outputs = fin_outputs
    targets = fin_targets

    # a class is predicted when its probability reaches the class-specific threshold
    outputs = np.array(outputs) >= LAMBDAS

    if postprocess_labels:
        # force every prediction to carry one or two labels
        for idx, output in enumerate(outputs):
            n_labels = sum(output)
            if not n_labels in [1,2]:
                fin_output = fin_outputs[idx][:]  # copy, it is modified below
                new_output = [False] * len(titles)
                if n_labels == 0:
                    # nothing above threshold: take the single most probable class
                    new_output[fin_output.index(max(fin_output))] = True
                else: # n_labels > 2: keep only the two most probable classes
                    while (sum(new_output) < 2):
                        new_output[fin_output.index(max(fin_output))] = True
                        fin_output[fin_output.index(max(fin_output))] = 0
                outputs[idx] = new_output

    if not validate:
        outputs_by_rank[evalset] = outputs
        targets_by_rank[evalset] = targets

    accuracy = metrics.accuracy_score(targets, outputs)
    f1_score_avg = metrics.f1_score(targets, outputs, average='samples')
    f1_score_micro = metrics.f1_score(targets, outputs, average='micro')
    f1_score_macro = metrics.f1_score(targets, outputs, average='macro')

    print(f"Accuracy Score = {accuracy}")
    print(f"F1 Score (Samples) = {f1_score_avg}")
    print(f"F1 Score (Micro) = {f1_score_micro}")
    print(f"F1 Score (Macro) = {f1_score_macro}")

    if WANDB:
        wandb.config.update({"test data": evalset})
        wandb.log({"test/acc": accuracy,
                   "test/f1_samples": f1_score_avg,
                   "test/f1_micro": f1_score_micro,
                   "test/f1_macro": f1_score_macro})

    classification_report = metrics.classification_report(
        targets,
        outputs,
        output_dict=False,
        target_names=titles,
        digits=4)

    # the report is written to the results file as well as shown in the notebook
    with open(output_path_results+model_id+evalset+"_results.txt", "w") as f:
        print(f"F1 Score (Samples) = {f1_score_avg}",f"Accuracy Score = {accuracy}",f"F1 Score (Micro) = {f1_score_micro}",f"F1 Score (Macro) = {f1_score_macro}", file=f)
        print("--- Classification Report: ---", file=f)
        print(classification_report, file=f)
    print("--- Classification Report: ---")
    print(classification_report)

### further evaluation:

In [None]:
def label_totitles(labels):
    """Return the section names selected by a multi-hot label vector.

    `labels` must be a sequence whose entries are int or convertible to int;
    the titles at all positions whose entry equals 1 are joined with ", ".
    An all-zero (or empty) vector yields "".
    """
    selected = [titles[idx] for idx, flag in enumerate(labels) if int(flag) == 1]
    return ", ".join(selected)

# Readable true/predicted labels for the `targets` / `outputs` left over from the evaluation
# cell, i.e. those of the evaluation set that was processed last.
df_labels = pd.DataFrame()
df_labels["label"] = [label_totitles(t) for t in targets]
df_labels["pred"] = [label_totitles(o) for o in outputs]
print(df_labels["pred"].value_counts())
#? df_labels["ids"] = ids

In [None]:
# Sanity check: how many predictions carry more than two labels?
print(len(df_labels["label"]))
#print(len(df_labels[len(df_labels['label']) > 2]))
print(len(outputs))
niceos = [o for o in outputs if sum(o)<=2] # predictions with at most two labels
bados = [o for o in outputs if sum(o)>2] # predictions with more than two labels
print(len(niceos))
print(len(bados))
print(len(niceos)+len(bados))

# share of predictions with more than two labels
len(bados)/len(outputs)

In [None]:
def scisensection(model, text_target, tokenizer, MAX_LEN):
    """Classify a single text.

    Args:
        model: trained classifier, called as model(ids, mask, token_type_ids).
        text_target: the text to classify.
        tokenizer: tokenizer providing `encode_plus`.
        MAX_LEN: length the token sequence is truncated/padded to.

    Returns:
        (names, probabilities): the predicted section names as a ", "-joined
        string and the list of per-class sigmoid probabilities.
    """
    model.eval()
    inputs = tokenizer.encode_plus(
        text_target,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        truncation=True,
        padding="max_length",
        return_token_type_ids=True
    )

    def _as_batch(values):
        # add the batch dimension: every model input gets shape (1, MAX_LEN)
        return torch.tensor(values, dtype=torch.long).unsqueeze(0).to(torch_device)

    ids = _as_batch(inputs['input_ids'])
    mask = _as_batch(inputs['attention_mask'])
    # token_type_ids previously lacked the batch dimension, unlike ids and mask
    token_type_ids = _as_batch(inputs['token_type_ids'])

    with torch.no_grad():
        output = model(ids, mask, token_type_ids)
    fin_output = torch.sigmoid(output).cpu().numpy().tolist()

    # apply the per-class thresholds
    output = np.array(fin_output) >= LAMBDAS
    return label_totitles(output[0]), fin_output[0]

In [None]:
# Quick manual check of the trained model on one example sentence.
text_target = "Prior research has shown that this works."
scisensection(model, text_target, tokenizer, MAX_LEN)

### examples and confusion plots

In [None]:
# examples per category
# for each predicted label combination, print:
# number of true labels that have been classified here
# up to 5 example sentences incl. their true labels
# NOTE(review): df_labels rows are in the order in which the evaluation loader yielded the
# samples, while the text is looked up by position in eval_dataset below. The two only line up
# if the loader iterated in dataset order and covered exactly eval_dataset — confirm.

unique_preds = list(set(df_labels["pred"]))

if validate:
    eval_dataset = valid_dataset
else:
    eval_dataset = test_dataset

for category in unique_preds:
    print(f"----- {category} -----")
    # all samples that received this predicted label combination
    df_cat = df_labels[df_labels['pred'] == category]
    print(df_cat["label"].value_counts())
    # random sample of at most 5 examples (unseeded, so it differs between runs)
    df_cat = df_cat.sample(n = min(5, len(df_cat)))
    for idx in df_cat.index.values:
        print(f"~~~ true label: {df_cat.loc[idx]['label']} ~~~")
        #print(fin_outputs[idx])
        print(eval_dataset.iloc[idx]["text_target"])
    print()

In [None]:
# confusion plots
# one bar chart per true label combination, showing how often each label combination was predicted for it

for title in set(df_labels["label"]):
    df_temp = df_labels[df_labels["label"] == title]
    # single-row cross table: counts of every predicted combination for this true combination
    crosstab = pd.crosstab(df_temp["label"], df_temp["pred"], rownames=["label"], colnames=["pred"])
    ax = crosstab.plot.bar(rot=0, figsize = (25, 10), width = 0.75)
    # annotate every bar with its count
    for container in ax.containers:
        ax.bar_label(container)