In [1]:
import torch
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os
import wandb

wandb.init(
    project="impossible-querry-pipeline",
    name="gemini-bert"
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmehdinejjar[0m ([33mmehdinejjar-al-akhawayn-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


'cuda'

In [2]:
checkpoint = "bert-base-cased"

data_path = "../dataset/dataset-match-score.csv"
df = pd.read_csv(data_path)
df['text'] = df['text'].apply(lambda x: x.replace('\n', ' ').replace('*', ''))


def compute_metrics(pred):
    """
    Computes accuracy, F1, precision, and recall for a given set of predictions.

    Args:
        pred (obj): An object containing label_ids and predictions attributes.
            - label_ids (array-like): A 1D array of true class labels.
            - predictions (array-like): A 2D array where each row represents
              an observation, and each column represents the probability of
              that observation belonging to a certain class.

    Returns:
        dict: A dictionary containing the following metrics:
            - Accuracy (float): The proportion of correctly classified instances.
            - F1 (float): The macro F1 score, which is the harmonic mean of precision
              and recall. Macro averaging calculates the metric independently for
              each class and then takes the average.
            - Precision (float): The macro precision, which is the number of true
              positives divided by the sum of true positives and false positives.
            - Recall (float): The macro recall, which is the number of true positives
              divided by the sum of true positives and false negatives.
    """
    labels = pred.label_ids

    preds = pred.predictions.argmax(-1)

    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')

    acc = accuracy_score(labels, preds)

    return {
        'Accuracy': acc,
        'F1': f1,
        'Precision': precision,
        'Recall': recall
    }

labels = df['labels'].unique().tolist()
labels.sort()
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}


train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['text'], df['labels'], test_size=0.2
)

tokenizer = BertTokenizerFast.from_pretrained(checkpoint, max_length=512)
train_encodings = tokenizer(list(train_texts), truncation=True, padding=True, return_tensors="pt")
val_encodings = tokenizer(list(val_texts), truncation=True, padding=True, return_tensors="pt")



In [3]:
class CustomDataset(Dataset):
    """
    Custom Dataset class for handling tokenized text data and corresponding labels.
    Inherits from torch.utils.data.Dataset.
    """
    def __init__(self, encodings, labels):
        """
        Initializes the DataLoader class with encodings and labels.

        Args:
            encodings (dict): A dictionary containing tokenized input text data
                              (e.g., 'input_ids', 'token_type_ids', 'attention_mask').
            labels (list): A list of integer labels for the input text data.
        """
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """
        Returns a dictionary containing tokenized data and the corresponding label for a given index.

        Args:
            idx (int): The index of the data item to retrieve.

        Returns:
            item (dict): A dictionary containing the tokenized data and the corresponding label.
        """
        item = {key: torch.tensor(val[idx]).clone().detach() for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels.iloc[idx]).clone().detach()
        return item

    def __len__(self):
        """
        Returns the number of data items in the dataset.

        Returns:
            (int): The number of data items in the dataset.
        """
        return len(self.labels)

In [4]:
train_dataset = CustomDataset(train_encodings, train_labels)
val_dataset = CustomDataset(val_encodings, val_labels)

num_labels = len(label2id)
model = BertForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
).to(device)

training_args = TrainingArguments(
      "train-checkpoints", 
      num_train_epochs=5, 
      eval_strategy="steps", 
      weight_decay=5e-4, 
      per_device_train_batch_size=128,
      per_device_eval_batch_size=128,
      save_strategy="steps",
      logging_steps=50,
      load_best_model_at_end=True,
      run_name="generating-impossible-query",
    #   report_to="wandb",
)

trainer = Trainer(
    # the pre-trained model that will be fine-tuned
    model=model,
     # training arguments that we defined above
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics= compute_metrics
)

trainer.train()

tokenizer.save_pretrained("train-checkpoints/best-model")
trainer.save_model("train-checkpoints/best-model")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  item = {key: torch.tensor(val[idx]).clone().detach() for key, val in self.encodings.items()}
 42%|████▏     | 50/120 [00:34<00:47,  1.46it/s]

{'loss': 0.3714, 'grad_norm': 0.1995501071214676, 'learning_rate': 2.916666666666667e-05, 'epoch': 12.5}


                                                
  item = {key: torch.tensor(val[idx]).clone().detach() for key, val in self.encodings.items()}


{'eval_loss': 0.109756700694561, 'eval_Accuracy': 0.984, 'eval_F1': 0.984650270364556, 'eval_Precision': 0.9866666666666667, 'eval_Recall': 0.9833333333333334, 'eval_runtime': 0.2207, 'eval_samples_per_second': 566.35, 'eval_steps_per_second': 4.531, 'epoch': 12.5}


 83%|████████▎ | 100/120 [01:08<00:13,  1.50it/s]

{'loss': 0.0074, 'grad_norm': 0.049501340836286545, 'learning_rate': 8.333333333333334e-06, 'epoch': 25.0}


                                                 
  item = {key: torch.tensor(val[idx]).clone().detach() for key, val in self.encodings.items()}


{'eval_loss': 0.1372523158788681, 'eval_Accuracy': 0.968, 'eval_F1': 0.9691228070175438, 'eval_Precision': 0.9743589743589745, 'eval_Recall': 0.9666666666666667, 'eval_runtime': 0.215, 'eval_samples_per_second': 581.365, 'eval_steps_per_second': 4.651, 'epoch': 25.0}


100%|██████████| 120/120 [01:23<00:00,  1.43it/s]


{'train_runtime': 83.9069, 'train_samples_per_second': 177.339, 'train_steps_per_second': 1.43, 'train_loss': 0.15829989286139606, 'epoch': 30.0}


In [5]:
from sklearn.metrics import classification_report
import numpy as np

# Make prediction on evaluation dataset
y_pred = trainer.predict(val_dataset).predictions
y_pred = np.argmax(y_pred, axis=-1)

# Get the true labels
y_true = val_dataset.labels
y_true = np.array(y_true)

# Print the classification report
print(classification_report(y_true, y_pred, digits=4))

  item = {key: torch.tensor(val[idx]).clone().detach() for key, val in self.encodings.items()}
100%|██████████| 1/1 [00:00<00:00, 450.61it/s]

              precision    recall  f1-score   support

           0     0.9600    1.0000    0.9796        48
           1     1.0000    0.9500    0.9744        40
           2     1.0000    1.0000    1.0000        37

    accuracy                         0.9840       125
   macro avg     0.9867    0.9833    0.9847       125
weighted avg     0.9846    0.9840    0.9840       125




