# Arash Hajian nezhad | BERT fine-tuning

#### Dependencies

In [None]:
# %%capture
# !pip install transformers==4.28.0

#### Imports

In [2]:
import ast

import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score

from transformers import AutoModelForSequenceClassification, AutoTokenizer, logging, TrainingArguments, Trainer, EvalPrediction
logging.set_verbosity_error()

#### Loading training dataframe

In [3]:
df = pd.read_csv('data/processed/stories_processed.csv')
df.head()

Unnamed: 0,body,labels
0,hello and welcome to BBC News a woman who gave...,"[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,news now out of North Hollywood. A 14 yearold ...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
2,homelessness his city's greatest failure. That...,"[1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
3,Minneapolis police officer Kim Potter guilty o...,"[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
4,Judy an update now to the wildfires that wiped...,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


#### Dataset class definition

In [4]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


class StoriesDataset(Dataset):
    def __init__(self, df):
        super().__init__()

        self.data = [tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors='pt') for text in df['body']]
        labels = [torch.tensor(ast.literal_eval(label)).float() for label in df['labels'].values]

        for i, data_point in enumerate(self.data):
            for key in data_point:
                self.data[i][key] = self.data[i][key].squeeze(0)  # squeezing the extra first dimension in input_ids, mask_ids and attention tensors

            self.data[i]['labels'] = labels[i]  # adding the labels

        self.__size = len(self.data)
    
    def __len__(self):
        return self.__size
    
    def __getitem__(self, idx):
        return self.data[idx]

#### Saving topics to labels mappings and vice versa

In [5]:
stories = pd.read_csv('data/processed/stories.csv')

stories = stories[stories['body'] != ' '].reset_index().drop(['index'], axis=1)

topics = []
for i in stories['topic']:
    current_topics = ast.literal_eval(i)
    for topic in current_topics:
        if topic not in topics:
            topics.append(topic)


topic_to_label = {topic: label for label, topic in enumerate(topics)}
label_to_topic = {label: topic for label, topic in enumerate(topics)}

#### Preparing model

In [6]:
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased',
                                                           problem_type='multi_label_classification',
                                                           num_labels=15,
                                                           id2label=topic_to_label,
                                                           label2id=label_to_topic,
)

#### Defining some training metric for Train API usage

In [7]:
def multi_label_metrics(predictions, labels, sigmoid_threshold=0.5):
    """
    Function for computing metrics for usage in Trainer API.

    Args:
        predictions: model's output.
        labels: true labels of the input data.
        sigmoid_threshold: a `float` object representing the threshold which the label
                           would be counted as `1` after applying `sigmoid`.
    
    Returns:
        a dictionary containing some training metrics.
    """
    # applying sigmoid on model's predictions
    probabilities = torch.Tensor(predictions).sigmoid()

    # applying threshold
    y_pred = np.zeros(probabilities.shape)
    y_pred[np.where(probabilities >= sigmoid_threshold)] = 1

    # compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average='micro')
    accuracy = accuracy_score(y_true, y_pred)

    # return the results as a dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy,
    }

    return metrics


def compute_metrics(p: EvalPrediction):
    """
    Function for fetching the training metrics.

    Args:
        p: an `EvalPrediction` object related to Trainer API.
    
    Returns:
        some training metrics.
    """
    predictions = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    metrics = multi_label_metrics(
        predictions=predictions, 
        labels=p.label_ids,
    )

    return metrics

#### Splitting dataset

In [8]:
df_train, df_valid = train_test_split(df, train_size=0.85, shuffle=True)
BATCH_SIZE = 16

#### Training arguments for Trainer API

In [9]:
trainer_args = TrainingArguments(
    f'bert-finetuned-stories-topic',
    evaluation_strategy = 'epoch',
    save_strategy = 'epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE // 2,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    push_to_hub=False,
)

#### Preparing the final Trainer API

In [10]:
trainer = Trainer(
    model,
    trainer_args,
    train_dataset=StoriesDataset(df_train),
    eval_dataset=StoriesDataset(df_valid),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

#### Training the model

In [11]:
trainer.train()



{'eval_loss': 0.19203034043312073, 'eval_f1': 0.5804962492787074, 'eval_roc_auc': 0.7140161528285432, 'eval_accuracy': 0.31953428201811124, 'eval_runtime': 24.3479, 'eval_samples_per_second': 31.748, 'eval_steps_per_second': 3.984, 'epoch': 1.0}
{'loss': 0.2355, 'learning_rate': 1.27007299270073e-05, 'epoch': 1.82}
{'eval_loss': 0.1526438593864441, 'eval_f1': 0.6862231534834992, 'eval_roc_auc': 0.7786105620691107, 'eval_accuracy': 0.4320827943078913, 'eval_runtime': 24.2918, 'eval_samples_per_second': 31.821, 'eval_steps_per_second': 3.993, 'epoch': 2.0}
{'eval_loss': 0.13900378346443176, 'eval_f1': 0.7315136476426798, 'eval_roc_auc': 0.8129281547006817, 'eval_accuracy': 0.48124191461837, 'eval_runtime': 24.7301, 'eval_samples_per_second': 31.258, 'eval_steps_per_second': 3.922, 'epoch': 3.0}
{'loss': 0.1361, 'learning_rate': 5.401459854014599e-06, 'epoch': 3.65}
{'eval_loss': 0.13401706516742706, 'eval_f1': 0.7452322738386309, 'eval_roc_auc': 0.8235018169487525, 'eval_accuracy': 0.495

TrainOutput(global_step=1370, training_loss=0.16619258797081718, metrics={'train_runtime': 2128.8787, 'train_samples_per_second': 10.285, 'train_steps_per_second': 0.644, 'train_loss': 0.16619258797081718, 'epoch': 5.0})

#### Functions for applying the predictions on the final dataframe
The output of the model is logits of shape `batch_size x 15`. Here we are feeding the model only one text at a time, hence the batch size would be 1, which we need to eliminate this redundant shape. We do this using PyTorch's `squeeze` method, which does exactly that. We must also take into account that the models outputs must be between 0 and 1, so we apply sigmoid function to do just that. (this has already been applied at the train time using `BCEWithLogitsLoss`)

In [20]:
def get_label_ids(logits: torch.Tensor, sigmoid_threshold: float = 0.5) -> list:
    """
    Function for getting transforming the model's output to its respective
    topic ids.

    Args:
        logits: a `torch.Tensor` object which is the output of the model.
        sigmoid_threshold: a `float` object representing the threshold which the label
                           would be counted as `1` after applying `sigmoid`.
    
    Returns:
        a list of topic ids based on the model's output and sigmoid threshold.
    """
    probabilities = logits.squeeze().cpu().sigmoid()
    predictions = np.zeros(probabilities.shape)
    predictions[np.where(probabilities > sigmoid_threshold)] = 1

    predicted_labels = str([str(label_to_topic[i]) for i, label in enumerate(predictions) if label == 1.0])
    return predicted_labels


def generate_topic(row: pd.DataFrame, sigmoid_threshold: float = 0.5) -> pd.DataFrame:
    """
    Function for applying the prediction of topics of the text bodies.

    Args:
        row: a row of a pandas dataframe.
        sigmoid_threshold: a `float` object representing the threshold which the label
                           would be counted as `1` after applying `sigmoid`.
    
    Returns:
        a row of a pandas dataframe containing a list of topics.
    """
    text = row['body']
    
    encoding = tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    encoding = {key: value.to(trainer.model.device) for key, value in encoding.items()}

    model_output = model(**encoding).logits

    row['topic'] = get_label_ids(model_output, sigmoid_threshold=sigmoid_threshold)

    return row

#### Loading the final dataframe

In [21]:
final = pd.read_csv('data/processed/to_fill_proccessed.csv')
final.head()

Unnamed: 0,first_words,last_words,source_video_id,body,start,end
0,Well knew. This morning police need your help,"gunpoint, beating him and stealing his cell ph...",18246,Well knew. This morning police need your help ...,464928,505910
1,a call. San Francisco firefighters rescued a man,all the way down to the ocean.,12387,a call. San Francisco firefighters rescued a m...,359020,385526
2,"Paul. Meanwhile, the state set a record in","night through conservation, some 4000 conserva...",16859,"Paul. Meanwhile, the state set a record in ene...",60704,101238
3,Emergency crews in Florida continue to search for,"in Florida to more than 850,000 homes.",18246,Emergency crews in Florida continue to search ...,505290,534958
4,But even though the state never ordered rolling,feel since their power got cut out needlessly.,16859,But even though the state never ordered rollin...,100910,283606


#### Applying the predictions and checking it out

In [22]:
final = final.apply(generate_topic, axis=1)
final

Unnamed: 0,first_words,last_words,source_video_id,body,start,end,topic
0,Well knew. This morning police need your help,"gunpoint, beating him and stealing his cell ph...",18246,Well knew. This morning police need your help ...,464928,505910,['9ff54ded-904b-4e0c-85ce-a3617f5cb913']
1,a call. San Francisco firefighters rescued a man,all the way down to the ocean.,12387,a call. San Francisco firefighters rescued a m...,359020,385526,['9ff54ded-904b-4e0c-85ce-a3617f5cb913']
2,"Paul. Meanwhile, the state set a record in","night through conservation, some 4000 conserva...",16859,"Paul. Meanwhile, the state set a record in ene...",60704,101238,['83a09c6b-5f2f-421f-ae50-b38acca7e008']
3,Emergency crews in Florida continue to search for,"in Florida to more than 850,000 homes.",18246,Emergency crews in Florida continue to search ...,505290,534958,"['9ff54ded-904b-4e0c-85ce-a3617f5cb913', '9a06..."
4,But even though the state never ordered rolling,feel since their power got cut out needlessly.,16859,But even though the state never ordered rollin...,100910,283606,[]
5,"aid. And today, president Joe Biden and first",to view the destruction caused by Hurricane Ian.,18246,"aid. And today, president Joe Biden and first ...",546494,593818,['83a09c6b-5f2f-421f-ae50-b38acca7e008']
6,"In the last month, there have been numerous",are necessary to crack down on those hackers.,18246,"In the last month, there have been numerous da...",614910,700574,"['83a09c6b-5f2f-421f-ae50-b38acca7e008', '9ff5..."
7,and the warriors are playing the Boston Celtics,that. We'll see if they get it tonight.,12387,and the warriors are playing the Boston Celtic...,419994,649502,['b49207eb-96eb-4b73-b534-adc0ef85022a']
8,And San Leandro police searching for the person,footage to try to piece together more informat...,16859,And San Leandro police searching for the perso...,578612,619638,['9ff54ded-904b-4e0c-85ce-a3617f5cb913']
9,The updated Bivalent Coronavirus booster shot ...,on their vaccinations getting severe illness f...,16859,The updated Bivalent Coronavirus booster shot ...,619310,660654,['96326734-fd82-4350-b45c-513e7eb9147c']


We observe that some of the rows have no topics predicted for them. We will try lowering the `sigmoid_threshold` this time to see if it gets any better.

In [23]:
final = final.apply(generate_topic, sigmoid_threshold=0.3, axis=1)
final

Unnamed: 0,first_words,last_words,source_video_id,body,start,end,topic
0,Well knew. This morning police need your help,"gunpoint, beating him and stealing his cell ph...",18246,Well knew. This morning police need your help ...,464928,505910,['9ff54ded-904b-4e0c-85ce-a3617f5cb913']
1,a call. San Francisco firefighters rescued a man,all the way down to the ocean.,12387,a call. San Francisco firefighters rescued a m...,359020,385526,['9ff54ded-904b-4e0c-85ce-a3617f5cb913']
2,"Paul. Meanwhile, the state set a record in","night through conservation, some 4000 conserva...",16859,"Paul. Meanwhile, the state set a record in ene...",60704,101238,"['83a09c6b-5f2f-421f-ae50-b38acca7e008', '3982..."
3,Emergency crews in Florida continue to search for,"in Florida to more than 850,000 homes.",18246,Emergency crews in Florida continue to search ...,505290,534958,"['9ff54ded-904b-4e0c-85ce-a3617f5cb913', '9a06..."
4,But even though the state never ordered rolling,feel since their power got cut out needlessly.,16859,But even though the state never ordered rollin...,100910,283606,['39822b5f-e37e-43e8-b997-7142fe55c3ea']
5,"aid. And today, president Joe Biden and first",to view the destruction caused by Hurricane Ian.,18246,"aid. And today, president Joe Biden and first ...",546494,593818,"['83a09c6b-5f2f-421f-ae50-b38acca7e008', '9a06..."
6,"In the last month, there have been numerous",are necessary to crack down on those hackers.,18246,"In the last month, there have been numerous da...",614910,700574,"['83a09c6b-5f2f-421f-ae50-b38acca7e008', '9ff5..."
7,and the warriors are playing the Boston Celtics,that. We'll see if they get it tonight.,12387,and the warriors are playing the Boston Celtic...,419994,649502,['b49207eb-96eb-4b73-b534-adc0ef85022a']
8,And San Leandro police searching for the person,footage to try to piece together more informat...,16859,And San Leandro police searching for the perso...,578612,619638,['9ff54ded-904b-4e0c-85ce-a3617f5cb913']
9,The updated Bivalent Coronavirus booster shot ...,on their vaccinations getting severe illness f...,16859,The updated Bivalent Coronavirus booster shot ...,619310,660654,['96326734-fd82-4350-b45c-513e7eb9147c']


Much better now! This looks very good.

#### Saving the dataframe

In [24]:
final.to_csv('to_fill_finalized_BERT.csv', index=False)