In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import get_scheduler
import logging

  from .autonotebook import tqdm as notebook_tqdm


## Dataset

In [2]:
df = pd.read_csv("../data/merged.csv")
df.head()

Unnamed: 0,category,prompt,need_retrieval,title,context,answers
0,closed_qa,When did Virgin Australia start operating? Vir...,0,,,
1,classification,Which is a species of fish? Tope or Rope,0,,,
2,open_qa,Why can camels survive for long without water?,1,,,
3,open_qa,"Alice's parents have three daughters: Amy, Jes...",1,,,
4,closed_qa,When was Tomoaki Komorida born? Komorida was b...,0,,,


The dataset is imbalanced, meaning some classes have significantly more examples than others. \
To address this and ensure fair representation, I decided to use a balanced subset of the data. \
Specifically, I limited the dataset to 1,500 examples per class. \
This approach helps prevent the model from being biased toward the majority class while still providing enough data for training.

In [3]:
num_samples_per_class = 1000
balanced_df = (
    df.groupby("need_retrieval", group_keys=False)
    .apply(
        lambda x: x.sample(num_samples_per_class, random_state=42), include_groups=True
    )
    .reset_index(drop=True)
)
balanced_df = balanced_df.sample(frac=1, random_state=42).reset_index(drop=True)

  .apply(lambda x: x.sample(num_samples_per_class, random_state=42), include_groups=True)


In [4]:
balanced_df.head()

Unnamed: 0,category,prompt,need_retrieval,title,context,answers
0,general_qa,What are the main disadvantages of electric ca...,1,,,
1,,"Write like a noir detective: Adopt the gritty,...",0,,,
2,brainstorming,What are the new 7 Wonders Cities:,1,,,
3,information_extraction,Depict the valuation of Adani group as mention...,0,,,
4,brainstorming,Give me a list of items I should bring to the ...,1,,,


In [5]:
train_df, temp_df = train_test_split(balanced_df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

## Tokenizer

"bert-base-uncased is an effective model for understanding language patterns. \
Its Transformer architecture captures complex word relationships, and its uncased nature simplifies text processing. \
Widely supported and easy to fine-tune, it’s a strong choice for many NLP tasks.

The authors of the article used BERT-basemultilingual-cased. \
I will use only English, therefore there is no need to use multilingual BERT.

In [None]:
tokenizer_path = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)


def tokenize_data(df, tokenizer, max_length=512):
    return tokenizer(
        df["prompt"].tolist(),
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

In [None]:
model_path = "bert-base-uncased"
num_labels = 2

model = BertForSequenceClassification.from_pretrained(model_path, num_labels=num_labels)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [8]:
train_encodings = tokenize_data(train_df, tokenizer)
val_encodings = tokenize_data(val_df, tokenizer)
test_encodings = tokenize_data(test_df, tokenizer)

In [9]:
class TextClassificationDataset(Dataset):
    def __init__(self, encodings, labels, indices):
        self.encodings = encodings
        self.labels = labels
        self.indices = indices

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        item["idx"] = torch.tensor(self.indices[idx])
        return item

    def __len__(self):
        return len(self.labels)


train_labels = train_df["need_retrieval"].tolist()
val_labels = val_df["need_retrieval"].tolist()
test_labels = test_df["need_retrieval"].tolist()

# saving the indices of specific samples in the dataset so that I can easily retrieve them later when needed
train_indices = train_df.index.tolist()
val_indices = val_df.index.tolist()
test_indices = test_df.index.tolist()


train_dataset = TextClassificationDataset(train_encodings, train_labels, train_indices)
val_dataset = TextClassificationDataset(val_encodings, val_labels, val_indices)
test_dataset = TextClassificationDataset(test_encodings, test_labels, test_indices)

In [10]:
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 3
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)



In [11]:
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def train_epoch(model, train_loader, optimizer, lr_scheduler, device):
    model.train()
    train_loss = 0
    train_preds, train_labels = [], []

    for batch in tqdm(train_loader, desc="Training", leave=False):
        batch = {k: v.to(device) for k, v in batch.items() if k != "idx"}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        train_loss += loss.item()
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(batch["labels"].cpu().numpy())

    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = accuracy_score(train_labels, train_preds)
    train_f1 = f1_score(train_labels, train_preds)

    return avg_train_loss, train_accuracy, train_f1


def validate_epoch(model, val_loader, device):
    model.eval()
    val_loss = 0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation", leave=False):
            batch = {k: v.to(device) for k, v in batch.items() if k != "idx"}
            outputs = model(**batch)
            logits = outputs.logits
            loss = outputs.loss
            val_loss += loss.item()

            preds = torch.argmax(logits, dim=-1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(batch["labels"].cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds)

    return avg_val_loss, val_accuracy, val_f1

In [12]:
def train(
    model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    device,
    num_epochs: int,
    save_dir: str = "./bert-text-classification-model",
    metric: str = "f1",
    early_stopping_patience=None,
):
    best_val_metric = 0
    epochs_without_improvement = 0

    metrics = {
        "train_losses": [],
        "val_losses": [],
        "train_accuracies": [],
        "val_accuracies": [],
        "train_f1_scores": [],
        "val_f1_scores": [],
    }

    for epoch in range(num_epochs):
        logger.info(f"Epoch {epoch + 1}/{num_epochs}")

        avg_train_loss, train_accuracy, train_f1 = train_epoch(
            model, train_loader, optimizer, lr_scheduler, device
        )
        metrics["train_losses"].append(avg_train_loss)
        metrics["train_accuracies"].append(train_accuracy)
        metrics["train_f1_scores"].append(train_f1)
        logger.info(
            f"Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, Training F1: {train_f1:.4f}"
        )

        avg_val_loss, val_accuracy, val_f1 = validate_epoch(model, val_loader, device)
        metrics["val_losses"].append(avg_val_loss)
        metrics["val_accuracies"].append(val_accuracy)
        metrics["val_f1_scores"].append(val_f1)
        logger.info(
            f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, Validation F1: {val_f1:.4f}"
        )

        current_val_metric = val_f1 if metric == "f1" else val_accuracy

        if current_val_metric > best_val_metric:
            best_val_metric = current_val_metric
            model.save_pretrained(save_dir)
            logger.info(
                f"New best model saved with Validation {metric.capitalize()}: {best_val_metric:.4f}"
            )
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if (
            early_stopping_patience
            and epochs_without_improvement >= early_stopping_patience
        ):
            logger.info(f"Early stopping at epoch {epoch + 1} due to no improvement.")
            break

    logger.info(
        f"Loaded the best model with Validation {metric.capitalize()}: {best_val_metric:.4f}"
    )

    return metrics

In [13]:
def evaluate(
    model,
    test_loader,
    test_df,
    device,
):
    model.eval()
    test_results = []

    with torch.no_grad():
        for batch in test_loader:
            idxs = batch["idx"].cpu().numpy()
            batch = {k: v.to(device) for k, v in batch.items() if k != "idx"}

            outputs = model(**batch)

            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1).cpu().numpy().tolist()
            actual_labels = batch["labels"].cpu().numpy().tolist()

            texts = test_df.loc[idxs, "prompt"].tolist()
            categories = test_df.loc[idxs, "category"].tolist()

            test_results.extend(zip(texts, categories, actual_labels, preds))

    results_df = pd.DataFrame(
        test_results,
        columns=["Text", "Category", "Need_retrieval", "Predicted"],
    )

    logger.info("Test set evaluation completed.")
    return results_df

## Before training

In [14]:
results_df = evaluate(
    model=model, test_loader=test_loader, test_df=test_df, device=device
)
results_df.head()

2025-02-01 23:09:34,109 - INFO - Test set evaluation completed.


Unnamed: 0,Text,Category,Need_retrieval,Predicted
0,How do you play baseball,general_qa,1,1
1,In what area was Frédéric born in?,,1,1
2,Change the text into a medieval speech Climate...,,0,0
3,What types of political organizations did pre-...,closed_qa,0,0
4,What are the four largest British Virgin Islan...,open_qa,1,1


In [None]:
metrics = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    device=device,
    num_epochs=4,
    save_dir='./bert-text-classification-model',
    metric="f1",
    early_stopping_patience=3,
)


train_losses = metrics["train_losses"]
val_losses = metrics["val_losses"]
train_f1_scores = metrics["train_f1_scores"]
val_f1_scores = metrics["val_f1_scores"]
train_accuracies = metrics["train_accuracies"]
val_accuracies = metrics["val_accuracies"]

In [None]:
plt.figure(figsize=(12, 8))

plt.subplot(3, 1, 1)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.subplot(3, 1, 2)
plt.plot(train_accuracies, label='Training Accuracy')
plt.plot(val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(3, 1, 3)
plt.plot(train_f1_scores, label='Training F1 Score')
plt.plot(val_f1_scores, label='Validation F1 Score')
plt.xlabel('Epoch')
plt.ylabel('F1 Score')
plt.title('Training and Validation F1 Score')
plt.legend()

plt.tight_layout()
plt.show()

## After training

In [17]:
results_df = evaluate(
    model=model, test_loader=test_loader, test_df=test_df, device=device
)

2025-02-01 23:09:57,597 - INFO - Test set evaluation completed.


In [18]:
results_df.head()

Unnamed: 0,Text,Category,Need_retrieval,Predicted
0,How do you play baseball,general_qa,1,1
1,In what area was Frédéric born in?,,1,1
2,Change the text into a medieval speech Climate...,,0,0
3,What types of political organizations did pre-...,closed_qa,0,0
4,What are the four largest British Virgin Islan...,open_qa,1,1


In [19]:
results_df.to_csv("../data/predictions.csv")

In [20]:
test_df_2 = pd.read_csv("../data/test.csv")

test_df_2 = test_df_2.rename(columns={"question": "prompt"})
print(test_df_2.head())
test_encodings_2 = tokenize_data(test_df_2, tokenizer)

test_labels_2 = test_df_2["need_retrieval"].tolist()

# saving the indices of specific samples in the dataset so that I can easily retrieve them later when needed
test_indices_2 = test_df_2.index.tolist()


test_dataset_2 = TextClassificationDataset(
    test_encodings_2, test_labels_2, test_indices_2
)

test_loader_2 = DataLoader(test_dataset_2, batch_size=4, shuffle=False)


                                            title  \
0  Sino-Tibetan_relations_during_the_Ming_dynasty   
1                                             NaN   
2                                             NaN   
3                                         Beyoncé   
4                         2008_Sichuan_earthquake   

                                             context  \
0  Tsai writes that shortly after the visit by De...   
1                                                NaN   
2                                                NaN   
3  Her fourth studio album 4 was released on June...   
4  In the China Digital Times an article reports ...   

                                              prompt  \
0          What did Yongle want to trade with Tibet?   
1  Extract the cinema industry and the percentage...   
2  Adapt the text to make it relevant for a corpo...   
3  What magazine did Beyoncé write a story for ab...   
4           What did the China Digital Times report?   

       

In [21]:
validate_epoch(model, test_loader_2, device)

                                                             

(0.002303662709891796, 1.0, 1.0)

In [22]:
results_df_2 = evaluate(
    model=model, test_loader=test_loader_2, test_df=test_df_2, device=device
)

2025-02-01 23:11:31,466 - INFO - Test set evaluation completed.


In [23]:
results_df_2.head()

Unnamed: 0,Text,Category,Need_retrieval,Predicted
0,What did Yongle want to trade with Tibet?,,1,1
1,Extract the cinema industry and the percentage...,information_extraction,0,0
2,Adapt the text to make it relevant for a corpo...,,0,0
3,What magazine did Beyoncé write a story for ab...,,1,1
4,What did the China Digital Times report?,,1,1
