In [None]:
!pip install transformers



In [None]:
# Colab-only: mount Google Drive so files under /content/drive can be read
# (main() below loads its JSON data from a folder on this mount).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import numpy as np
import random

# Reproducibility: seed every random number generator the notebook relies on
# with the same fixed value.
seed = 123

# Python and NumPy generators.
random.seed(seed)
np.random.seed(seed)

# PyTorch generators (CPU, current CUDA device, and all CUDA devices).
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and turn off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Build the Model
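Fine-tuning DistilBERT for QA involves adding prediction heads on top of the encoder: a span head that outputs start and end logits for each token, and a classification head that predicts the answer type from the [CLS] token representation:

- Start logits predict the start token of the answer span.
- End logits predict the end token of the answer span.
- Answer-type logits predict whether the example has a short answer or no answer.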

In [None]:
from transformers import DistilBertModel, DistilBertTokenizerFast, DistilBertConfig
import torch.nn as nn

class DistilBERTQA(nn.Module):
    """DistilBERT encoder with a span-extraction head and an answer-type head.

    The span head maps every token's hidden state to two scores (start, end).
    The answer-type head classifies the example from the hidden state of the
    first token ([CLS]).

    Args:
        model_name: Name or path of the pretrained DistilBERT checkpoint.
        num_answer_types: Number of answer-type classes.
        dropout_prob: Dropout probability used in front of both heads.
    """

    def __init__(self, model_name='distilbert-base-uncased', num_answer_types=2, dropout_prob=0.1):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(model_name)

        # One dropout layer per head so each head sees independently masked features.
        self.dropout_seq = nn.Dropout(dropout_prob)  # in front of the span head
        self.dropout_cls = nn.Dropout(dropout_prob)  # in front of the answer-type head

        # Output layers (created in the same order as before so that seeded
        # initialisation produces the same weights).
        hidden_size = self.bert.config.hidden_size
        self.qa_outputs = nn.Linear(hidden_size, 2)  # start and end logits
        self.answer_type_head = nn.Linear(hidden_size, num_answer_types)  # type logits

    def forward(self, input_ids, attention_mask):
        """Run the encoder and both heads.

        Args:
            input_ids: Token ids, shape (batch_size, seq_len).
            attention_mask: Attention mask, shape (batch_size, seq_len).

        Returns:
            Tuple ``(start_logits, end_logits, answer_type_logits)`` with shapes
            (batch_size, seq_len), (batch_size, seq_len) and
            (batch_size, num_answer_types).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        # Span head: dropout -> linear -> split the two channels.
        sequence_output = self.dropout_seq(hidden_states)
        qa_logits = self.qa_outputs(sequence_output)  # (batch_size, seq_len, 2)
        start_logits, end_logits = qa_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)  # (batch_size, seq_len)
        end_logits = end_logits.squeeze(-1)      # (batch_size, seq_len)

        # Answer-type head: take [CLS] from the *raw* encoder output. Slicing it
        # from `sequence_output` would apply dropout twice (dropout_seq followed
        # by dropout_cls), regularising this head more than `dropout_prob` says.
        cls_representation = self.dropout_cls(hidden_states[:, 0, :])  # (batch_size, hidden_size)
        answer_type_logits = self.answer_type_head(cls_representation)

        return start_logits, end_logits, answer_type_logits


# Load Model and Tokenizer
def load_model(model_name="distilbert-base-uncased"):
    """Create the QA model and the tokenizer from the same checkpoint.

    Args:
        model_name: Name or path of the pretrained DistilBERT checkpoint. The
            same value is used for the encoder weights and for the tokenizer so
            the two cannot drift apart.

    Returns:
        Tuple ``(model, tokenizer)``: a ``DistilBERTQA`` instance and the
        matching ``DistilBertTokenizerFast``.
    """
    model = DistilBERTQA(model_name=model_name)
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    return model, tokenizer

## Dataset and Load Data

In [None]:
"""
Data instance: {
  'name': str,
  'id': str,
  'questions': [{'input_text': query: str}],
  'answers': [
         {
          'candidate_id': id: int,
          'span_text': ans_text: str,
          'span_start': start_index: int,
          'span_end': end_index: int,
          'input_text': type: str ('short' / 'no answer')
          }
  ],
  'has_correct_context': bool,
  'contexts': passage: str
}
"""

import json
from torch.utils.data import Dataset, DataLoader

# Dataset for QA
class QADataset(Dataset):
    """Torch dataset that turns QA records into model-ready tensors.

    Each record is read through the keys ``questions[0]["input_text"]``,
    ``contexts``, ``has_correct_context`` and, when an answer exists,
    ``answers[0]["span_start"]`` / ``answers[0]["span_end"]`` (character
    offsets into the context; ``span_end`` is matched as an exclusive end).

    Args:
        data: Indexable collection of records.
        tokenizer: Callable tokenizer that supports ``return_offsets_mapping``
            and whose output provides ``sequence_ids()`` (fast tokenizers).
        max_len: Length every example is truncated / padded to.
    """

    def __init__(self, data, tokenizer, max_len=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and locate its answer span in token space.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``start_positions``,
            ``end_positions`` and ``answer_type`` tensors. The span is (0, 0),
            i.e. the first token, when the record has no answer or when the
            answer cannot be fully located among the context tokens (for
            example because it was truncated away). ``answer_type`` is 1 when
            ``has_correct_context`` is true and 0 otherwise.
        """
        item = self.data[idx]
        question = item["questions"][0]["input_text"]
        context = item["contexts"]
        has_answer = item["has_correct_context"]

        inputs = self.tokenizer(
            question,
            context,
            truncation=True,
            max_length=self.max_len,
            return_offsets_mapping=True,
            padding="max_length"
        )
        offset_mapping = inputs.pop("offset_mapping")
        # Per-token segment id: 0 = question, 1 = context, None = special/padding.
        sequence_ids = inputs.sequence_ids()

        # Default start and end positions for no-answer ([CLS] token index)
        start_positions = 0
        end_positions = 0
        answer_type = 0  # Default: no answer

        if has_answer:
            answer = item["answers"][0]
            start_char = answer["span_start"]
            end_char = answer["span_end"]

            token_start = None
            token_end = None
            for token_idx, (tok_start, tok_end) in enumerate(offset_mapping):
                # Character offsets restart at 0 for each segment, so question
                # tokens could spuriously match the answer's offsets. Only
                # context tokens are valid candidates.
                if sequence_ids[token_idx] != 1:
                    continue
                if tok_start <= start_char < tok_end:
                    token_start = token_idx
                if tok_start < end_char <= tok_end:
                    token_end = token_idx

            # Use the span only when both ends were found; a half-located span
            # (e.g. the end was truncated) would be an inconsistent label.
            if token_start is not None and token_end is not None:
                start_positions = token_start
                end_positions = token_end

            answer_type = 1  # Short answer

        return {
            "input_ids": torch.tensor(inputs["input_ids"]),
            "attention_mask": torch.tensor(inputs["attention_mask"]),
            "start_positions": torch.tensor(start_positions),
            "end_positions": torch.tensor(end_positions),
            "answer_type": torch.tensor(answer_type),
        }



# Load JSON Data
def load_data(train_file, val_file):
    """Read the training and validation sets from two JSON files.

    The files are opened explicitly as UTF-8 so that non-ASCII text is decoded
    the same way regardless of the platform's default locale encoding.

    Args:
        train_file: Path to the training JSON file.
        val_file: Path to the validation JSON file.

    Returns:
        Tuple ``(train_data, val_data)`` holding the parsed JSON contents.
    """
    with open(train_file, "r", encoding="utf-8") as f:
        train_data = json.load(f)
    with open(val_file, "r", encoding="utf-8") as f:
        val_data = json.load(f)
    return train_data, val_data

## Train Loop and Evaluation Loop

In [None]:
from tqdm import tqdm
from torch.optim import AdamW

# Training Loop
def train_loop(model, train_data_loader, validation_data_loader, tokenizer, epochs=2, lr=3e-5, weight_decay=0.01, device="cpu"):
    """
    Fine-tune the model with AdamW and run validation after every epoch.

    Args:
        model: The QA model; called as ``model(input_ids, attention_mask)``.
        train_data_loader: DataLoader yielding training batches (dicts of tensors).
        validation_data_loader: DataLoader passed to ``validate`` after each epoch.
        tokenizer: Tokenizer forwarded to ``validate``.
        epochs (int): Number of passes over the training data (default: 2).
        lr (float): AdamW learning rate (default: 3e-5).
        weight_decay (float): AdamW weight decay (default: 0.01).
        device (str): Device the model and batches are moved to (default: 'cpu').

    Returns:
        train_losses (list): Mean training loss per epoch.
        val_losses (list): Validation loss per epoch, as returned by ``validate``.
    """
    batch_fields = ("input_ids", "attention_mask", "start_positions", "end_positions", "answer_type")

    model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        # validate() leaves the model in eval mode, so switch back every epoch.
        model.train()
        running_loss = 0

        progress = tqdm(train_data_loader, desc=f"Training Epoch {epoch+1}/{epochs}")
        for batch in progress:
            optimizer.zero_grad()

            # Move the tensors this step needs onto the target device.
            input_ids, attention_mask, start_positions, end_positions, answer_types = (
                batch[field].to(device) for field in batch_fields
            )

            # Forward pass and combined span + type loss.
            start_logits, end_logits, type_logits = model(input_ids, attention_mask)
            loss = compute_loss(
                start_logits, end_logits, type_logits,
                start_positions, end_positions, answer_types,
            )

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_losses.append(running_loss / len(train_data_loader))

        # Validation loss for this epoch (validate also prints span metrics).
        val_losses.append(validate(model, validation_data_loader, tokenizer, device))

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_losses[-1]:.4f} | Validation Loss: {val_losses[-1]:.4f}")

    return train_losses, val_losses


def validate(model, data_loader, tokenizer, device="cpu"):
    """
    Compute the validation loss and print batch-averaged span metrics.

    Args:
        model: The QA model; called as ``model(input_ids, attention_mask)``.
        data_loader: DataLoader for the validation dataset.
        tokenizer: Tokenizer used to convert input ids back to tokens.
        device: Device to run the model on.

    Returns:
        The summed batch losses divided by ``len(data_loader)``. Precision,
        recall and F1 (each averaged over batches) are printed, not returned.
    """
    model.to(device)
    model.eval()

    loss_total = 0
    precision_total = 0
    recall_total = 0
    f1_total = 0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            start_positions = batch["start_positions"].to(device)
            end_positions = batch["end_positions"].to(device)
            answer_types = batch["answer_type"].to(device)

            # Token strings per example, needed for the overlap metrics.
            batch_token_lists = [
                tokenizer.convert_ids_to_tokens(row) for row in input_ids.cpu().numpy()
            ]

            start_logits, end_logits, type_logits = model(input_ids, attention_mask)

            batch_loss = compute_loss(
                start_logits, end_logits, type_logits,
                start_positions, end_positions, answer_types,
            )
            loss_total += batch_loss.item()

            # Most likely start / end index per example, paired with the gold span.
            predicted_spans = list(zip(
                start_logits.argmax(dim=-1).cpu().numpy(),
                end_logits.argmax(dim=-1).cpu().numpy(),
            ))
            gold_spans = list(zip(
                start_positions.cpu().numpy(),
                end_positions.cpu().numpy(),
            ))

            precision, recall, f1 = compute_metrics(predicted_spans, gold_spans, batch_token_lists)
            precision_total += precision
            recall_total += recall
            f1_total += f1
            num_batches += 1

    # Each batch contributes equally to the reported averages.
    avg_precision = precision_total / num_batches
    avg_recall = recall_total / num_batches
    avg_f1 = f1_total / num_batches

    print(f"Span Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, F1: {avg_f1:.4f}")

    return loss_total / len(data_loader)


# Compute Loss
def compute_loss(start_logits, end_logits, type_logits, start_positions, end_positions, answer_types):
    """Return the sum of three cross-entropy losses: span start, span end, answer type.

    Args:
        start_logits, end_logits: Per-token scores for the span boundaries.
        type_logits: Scores over the answer-type classes.
        start_positions, end_positions: Target token indices for the span.
        answer_types: Target answer-type class indices.

    Returns:
        Tensor holding the unweighted sum of the three losses.
    """
    cross_entropy = torch.nn.functional.cross_entropy

    span_start_loss = cross_entropy(start_logits, start_positions)
    span_end_loss = cross_entropy(end_logits, end_positions)
    answer_type_loss = cross_entropy(type_logits, answer_types)

    return span_start_loss + span_end_loss + answer_type_loss


def eval_loop(model, data_loader, tokenizer, device="cpu"):
    """
    Score the model's span predictions on a dataset.

    Args:
        model: The QA model; called as ``model(input_ids, attention_mask)``.
        data_loader: DataLoader for the dataset to evaluate.
        tokenizer: Tokenizer used to convert input ids back to tokens.
        device: Device to run the model on.

    Returns:
        Tuple ``(precision, recall, f1)``, each computed per batch with
        ``compute_metrics`` and then averaged over the batches.
    """
    model.to(device)
    model.eval()

    precision_total = 0
    recall_total = 0
    f1_total = 0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            start_positions = batch["start_positions"].to(device)
            end_positions = batch["end_positions"].to(device)

            # Token strings per example, needed for the overlap metrics.
            batch_token_lists = [
                tokenizer.convert_ids_to_tokens(row) for row in input_ids.cpu().numpy()
            ]

            # The answer-type logits are not used for span scoring.
            start_logits, end_logits, _type_logits = model(input_ids, attention_mask)

            # Most likely start / end index per example, paired with the gold span.
            predicted_spans = list(zip(
                start_logits.argmax(dim=-1).cpu().numpy(),
                end_logits.argmax(dim=-1).cpu().numpy(),
            ))
            gold_spans = list(zip(
                start_positions.cpu().numpy(),
                end_positions.cpu().numpy(),
            ))

            precision, recall, f1 = compute_metrics(predicted_spans, gold_spans, batch_token_lists)
            precision_total += precision
            recall_total += recall
            f1_total += f1
            num_batches += 1

    # Each batch contributes equally to the reported averages.
    avg_precision = precision_total / num_batches
    avg_recall = recall_total / num_batches
    avg_f1 = f1_total / num_batches

    return avg_precision, avg_recall, avg_f1



def compute_metrics(predictions, labels, tokenized_inputs):
    """
    Precision, recall and F1 from the overlap of predicted and gold span tokens.

    For every example the tokens inside the predicted span and inside the gold
    span (both ends inclusive) are collected as sets, so a token string that
    occurs several times in a span counts once. A span whose start lies after
    its end contributes no tokens. Counts are pooled over all examples before
    the ratios are taken.

    Args:
        predictions: List of tuples [(start_pred, end_pred), ...] for each example.
        labels: List of tuples [(start_label, end_label), ...] for each example.
        tokenized_inputs: List of token lists corresponding to each example.

    Returns:
        precision, recall, f1 (each is 0 when its denominator is 0)
    """
    def span_tokens(tokens, first, last):
        # Inclusive slice; an inverted span yields the empty set.
        return set(tokens[first : last + 1]) if first <= last else set()

    hits = 0       # tokens in both the predicted and the gold span
    spurious = 0   # predicted tokens that are not in the gold span
    missed = 0     # gold tokens that were not predicted

    for (pred_start, pred_end), (gold_start, gold_end), tokens in zip(predictions, labels, tokenized_inputs):
        guessed = span_tokens(tokens, pred_start, pred_end)
        gold = span_tokens(tokens, gold_start, gold_end)

        hits += len(guessed & gold)
        spurious += len(guessed - gold)
        missed += len(gold - guessed)

    predicted_total = hits + spurious
    gold_total = hits + missed

    precision = hits / predicted_total if predicted_total > 0 else 0
    recall = hits / gold_total if gold_total > 0 else 0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0

    return precision, recall, f1

In [None]:
# Main Function
def main():
    """Train the QA model on the training split and report span metrics on the dev split.

    Loads the model/tokenizer and the two JSON files from Google Drive, builds
    the data loaders, runs ``train_loop`` and finally prints the precision,
    recall and F1 returned by ``eval_loop`` on the validation loader.
    """
    DIR_PATH = "/content/drive/MyDrive/Final Project/Data/"
    batch_size = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, tokenizer = load_model()
    train_data, validation_data = load_data(DIR_PATH + "all_train.json", DIR_PATH + "all_dev.json")

    train_dataset = QADataset(train_data, tokenizer)
    validation_dataset = QADataset(validation_data, tokenizer)

    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # One example per batch, so the batch-averaged metrics are per-example averages.
    validation_data_loader = DataLoader(validation_dataset, batch_size=1)

    # The per-epoch losses are printed by train_loop itself.
    train_loop(model, train_data_loader, validation_data_loader, tokenizer, device=device)

    precision, recall, f1_score = eval_loop(model, validation_data_loader, tokenizer, device=device)

    print("\n\n")
    print("PRECISION: ", precision)
    print("RECALL: ", recall)
    print("F1-SCORE: ", f1_score)


if __name__ == "__main__":
    main()

Training Epoch 1/2: 100%|██████████| 871/871 [22:40<00:00,  1.56s/it]
Validating: 100%|██████████| 1743/1743 [00:30<00:00, 57.39it/s]


Span Precision: 0.6530, Recall: 0.7101, F1: 0.6509
Epoch 1/2 | Train Loss: 3.7811 | Validation Loss: 2.7316


Training Epoch 2/2: 100%|██████████| 871/871 [22:40<00:00,  1.56s/it]
Validating: 100%|██████████| 1743/1743 [00:30<00:00, 57.32it/s]


Span Precision: 0.7111, Recall: 0.7312, F1: 0.6965
Epoch 2/2 | Train Loss: 2.2436 | Validation Loss: 2.4962


Evaluating: 100%|██████████| 1743/1743 [00:30<00:00, 57.18it/s]




PRECISION:  0.711126624192353
RECALL:  0.7311867682020301
F1-SCORE:  0.6965455687260241



