<a href="https://colab.research.google.com/github/Jae-YS/BERT_model/blob/main/Fine_tune_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch==2.2.1 transformers==4.40.0 datasets==2.16.0 evaluate==0.4.0 accelerate==0.32.0 peft==0.10.0


Collecting transformers==4.40.0
  Downloading transformers-4.40.0-py3-none-any.whl.metadata (137 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/137.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.6/137.6 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.20,>=0.19 (from transformers==4.40.0)
  Downloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.40.0-py3-none-any.whl (9.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.0/9.0 MB[0m [31m125.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, transformers
  Attempting uninstall: tokenizer

In [None]:
from datasets import load_dataset

# Fetch the SQuAD v1 dataset; later cells read its "train" and "validation" splits
dataset = load_dataset(path="squad")

In [None]:
from transformers import BertForQuestionAnswering, BertTokenizerFast

# The QA model and its fast tokenizer are both loaded from the same pre-trained checkpoint
MODEL_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(MODEL_CHECKPOINT)
model = BertForQuestionAnswering.from_pretrained(MODEL_CHECKPOINT)

# Helper: locate the token span of an answer inside one tokenized chunk
def _find_answer_token_span(offsets, sequence_ids, start_char, end_char):
    """Return the ``(start_token, end_token)`` pair covering a character span.

    Args:
        offsets: Per-token ``(char_start, char_end)`` pairs for one chunk.
        sequence_ids: Per-token sequence id for the same chunk; ``1`` marks
            context tokens.
        start_char: Inclusive character index where the answer begins.
        end_char: Exclusive character index where the answer ends.

    Returns:
        A ``(start_token, end_token)`` tuple, or ``None`` when the chunk has no
        context tokens, the answer is not fully inside the chunk, or either
        boundary does not fall inside a context token.
    """
    context_indices = [k for k, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not context_indices:
        return None
    first_ctx, last_ctx = context_indices[0], context_indices[-1]

    # The answer must lie entirely inside this chunk's slice of the context.
    if offsets[first_ctx][0] > start_char or offsets[last_ctx][1] < end_char:
        return None

    # First context token whose span contains the answer's first character.
    start_token = next(
        (k for k in range(first_ctx, last_ctx + 1)
         if offsets[k][0] <= start_char < offsets[k][1]),
        None,
    )
    # Last context token whose span contains the answer's final character.
    end_token = next(
        (k for k in range(last_ctx, first_ctx - 1, -1)
         if offsets[k][0] < end_char <= offsets[k][1]),
        None,
    )

    # Never emit a half-resolved label (one real boundary, one fallback).
    if start_token is None or end_token is None:
        return None
    return start_token, end_token


# Preprocessing Function: Converts raw examples into tokenized features
def preprocess_function(examples):
    """Tokenize a batch of QA examples into fixed-length chunks with span labels.

    Long contexts are split into overlapping chunks, so one example may yield
    several features. Each feature gets ``start_positions`` / ``end_positions``
    token labels; both point at the [CLS] token when the example has no answer
    or the answer cannot be located inside that chunk.

    Args:
        examples: Batch (dict of lists) with keys ``"id"``, ``"question"``,
            ``"context"`` and ``"answers"`` (each answer holding
            ``"answer_start"`` and ``"text"`` lists).

    Returns:
        The tokenizer output extended with ``start_positions``,
        ``end_positions`` and ``example_id``, and with ``offset_mapping``
        replaced so that every non-context token maps to ``None``.
    """
    questions = [q.strip() for q in examples["question"]]
    # Contexts are deliberately NOT stripped: `answer_start` is a character index
    # into the original context, and the stored offsets are later used to slice
    # that same original string. Removing leading whitespace would shift both.
    contexts = examples["context"]

    # Tokenize the QA pairs with chunking for long contexts
    tokenized_inputs = tokenizer(
        questions,
        contexts,
        max_length=384,
        truncation="only_second",              # Only truncate the context (not the question)
        stride=128,                            # Use a stride for overlapping chunks
        return_overflowing_tokens=True,        # Enable splitting of long contexts
        return_offsets_mapping=True,           # Keep mapping from tokens to character positions
        padding="max_length",                  # Pad all sequences to max_length
    )

    # Track which original example maps to each tokenized chunk
    sample_map = tokenized_inputs["overflow_to_sample_mapping"]

    start_positions = []
    end_positions = []
    example_ids = []
    updated_offset_mapping = []

    # For each tokenized chunk
    for i, offsets in enumerate(tokenized_inputs["offset_mapping"]):
        input_ids = tokenized_inputs["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sequence_ids = tokenized_inputs.sequence_ids(i)

        # Map back to original sample
        sample_index = sample_map[i]
        answers = examples["answers"][sample_index]
        example_ids.append(examples["id"][sample_index])

        # Adjust offset mapping: only keep offsets for context tokens
        updated_offset_mapping.append(
            [o if sequence_ids[k] == 1 else None for k, o in enumerate(offsets)]
        )

        span = None
        has_answer = len(answers["answer_start"]) > 0 and answers["answer_start"][0] is not None
        if has_answer:
            # Character-level start/end positions
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            span = _find_answer_token_span(offsets, sequence_ids, start_char, end_char)

        # No answer, or answer not locatable in this chunk: point both labels to [CLS]
        start_token, end_token = span if span is not None else (cls_index, cls_index)
        start_positions.append(start_token)
        end_positions.append(end_token)

    # Add start/end labels and mapping for use in training and evaluation
    tokenized_inputs["start_positions"] = start_positions
    tokenized_inputs["end_positions"] = end_positions
    tokenized_inputs["example_id"] = example_ids
    tokenized_inputs["offset_mapping"] = updated_offset_mapping

    return tokenized_inputs

# Run the preprocessing over the dataset in batches. The raw columns are dropped
# because one example can expand into several features, so row counts no longer match.
raw_column_names = dataset["train"].column_names
tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_column_names,
    desc="Tokenizing SQuAD",
)


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tokenizing SQuAD:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [None]:
# Columns exposed (as PyTorch tensors) by the training set.
# `token_type_ids` has to be listed: columns left out of this whitelist are hidden from
# the training batches, whereas the validation set below exposes every column. Without it
# the model would be trained without segment ids but evaluated with them.
columns_to_keep = [
    'input_ids',
    'token_type_ids',
    'attention_mask',
    'offset_mapping',
    'start_positions',
    'end_positions',
]

# Training set returns PyTorch tensors restricted to the columns above
tokenized_datasets["train"] = tokenized_datasets["train"].with_format("torch", columns=columns_to_keep)

# Validation set returns native Python objects (needed for postprocessing with offset_mapping)
tokenized_datasets["validation"] = tokenized_datasets["validation"].with_format("python")

# Full datasets used for final training and evaluation
full_train = tokenized_datasets["train"]
full_val = tokenized_datasets["validation"]

# Subsets for Experimentation (each is a prefix of the corresponding tokenized split)

# Very small subset (for debugging or rapid testing)
small_train_dataset = tokenized_datasets["train"].select(range(500))
small_eval_dataset = tokenized_datasets["validation"].select(range(100))

# Medium-sized subset (for tuning hyperparameters before scaling up)
mid_train_dataset = tokenized_datasets["train"].select(range(2000))
mid_eval_dataset = tokenized_datasets["validation"].select(range(500))

# Large subset (near full scale; good for final training before full set)
larger_train_dataset = tokenized_datasets["train"].select(range(10000))
larger_eval_dataset = tokenized_datasets["validation"].select(range(2000))


In [None]:
# BERT QA Training Pipeline

# Hugging Face imports
from transformers import (
    BertForQuestionAnswering,
    BertTokenizerFast,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
import numpy as np
import evaluate
import string

# SQuAD metric object; its compute() is called by compute_metrics below
squad_metric = evaluate.load(path="squad")

# Compute Metrics Callback: Converts model predictions to answer strings, then calculates EM/F1
def compute_metrics(p):
    """Score raw QA logits against the validation references.

    Note: this always post-processes against the complete ``dataset["validation"]``
    and ``tokenized_datasets["validation"]`` splits, regardless of which eval
    dataset produced ``p``.

    Args:
        p: Prediction object whose ``predictions`` attribute holds the
            ``(start_logits, end_logits)`` pair.

    Returns:
        Whatever ``squad_metric.compute`` returns for the formatted
        predictions and references.
    """
    validation_examples = dataset["validation"]

    # Turn logits into one answer string per example id
    answers_by_id = postprocess_qa_predictions(
        examples=validation_examples,
        features=tokenized_datasets["validation"],
        raw_predictions=p.predictions,
        tokenizer=tokenizer,
        n_best_size=20,
        max_answer_length=30
    )

    # Shape predictions the way the metric expects them
    formatted_predictions = []
    for example_id, answer_text in answers_by_id.items():
        formatted_predictions.append({"id": example_id, "prediction_text": answer_text})

    # Gold answers, keyed by the same ids
    references = []
    for ex in validation_examples:
        references.append({"id": ex["id"], "answers": ex["answers"]})

    return squad_metric.compute(predictions=formatted_predictions, references=references)

# Custom Postprocessing Function: Converts start/end logits into answer spans
def postprocess_qa_predictions(
    examples,
    features,
    raw_predictions,
    tokenizer,
    n_best_size=20,
    max_answer_length=30,
    score_threshold=0.0
):
    """Pick the best-scoring answer span for every example.

    Args:
        examples: Iterable of mappings with at least ``"id"`` and ``"context"``.
        features: Indexable, iterable collection of mappings. Each must carry
            ``"offset_mapping"`` (per-token ``(start, end)`` character pairs,
            ``None`` for tokens that must not be part of an answer) and
            ``"example_id"`` (or, as a fallback, ``"id"``) naming the example it
            was cut from.
        raw_predictions: Pair ``(start_logits, end_logits)``; row ``i`` of each
            belongs to ``features[i]``.
        tokenizer: Unused; kept so existing callers keep working.
        n_best_size: How many top start and top end positions to combine.
        max_answer_length: Maximum answer length in tokens.
        score_threshold: If the best span's score (start logit + end logit) is
            below this value, the empty string is predicted instead.

    Returns:
        ``collections.OrderedDict`` mapping example id to predicted answer text,
        in the order of ``examples``. Examples without any valid span map to ``""``.
    """
    import collections
    import numpy as np

    all_start_logits, all_end_logits = raw_predictions

    # Group feature indices by the example they were cut from.
    # (An explicit membership test is required here: `dict.get(key, feature["id"])`
    # would evaluate the fallback eagerly and fail for features without an "id".)
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        key = feature["example_id"] if "example_id" in feature else feature["id"]
        features_per_example[key].append(i)

    predictions = collections.OrderedDict()

    for example in examples:
        example_id = example["id"]
        context = example["context"]

        # Best span seen so far across all chunks of this example
        best_score = None
        best_text = ""

        for i in features_per_example.get(example_id, []):
            start_logits = all_start_logits[i]
            end_logits = all_end_logits[i]
            # Fetch the feature row once; only its offsets are needed
            offset_mapping = features[i]["offset_mapping"]
            n_offsets = len(offset_mapping)

            # Top n_best start and end token positions
            start_indexes = np.argsort(start_logits)[-1: -n_best_size - 1: -1].tolist()
            end_indexes = np.argsort(end_logits)[-1: -n_best_size - 1: -1].tolist()

            for start_index in start_indexes:
                # A start outside the context can never yield a valid span
                if start_index >= n_offsets or offset_mapping[start_index] is None:
                    continue

                for end_index in end_indexes:
                    # Skip invalid spans
                    if (
                        end_index >= n_offsets
                        or offset_mapping[end_index] is None
                        or end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    score = start_logits[start_index] + end_logits[end_index]
                    # Strict comparison keeps the first span on ties
                    if best_score is None or score > best_score:
                        start_char = offset_mapping[start_index][0]
                        end_char = offset_mapping[end_index][1]
                        best_score = score
                        best_text = context[start_char:end_char].strip()

        # Select the best valid answer span (or nothing if absent / below threshold)
        if best_score is None or best_score < score_threshold:
            predictions[example_id] = ""
        else:
            predictions[example_id] = best_text

    return predictions

# Training Arguments (grouped by concern)
training_args = TrainingArguments(
    # --- Output / logging ---
    output_dir="./bert_qa_results",     # Checkpoints are written here
    report_to="wandb",                  # Send logs to Weights & Biases
    logging_steps=50,                   # Log every 50 steps

    # --- Optimisation ---
    learning_rate=3e-5,
    weight_decay=0.01,
    lr_scheduler_type="linear",         # Linear LR decay
    warmup_ratio=0.1,                   # Warm up over the first 10% of steps
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,      # No accumulation
    num_train_epochs=5,                 # Upper bound on epochs
    fp16=True,                          # Mixed precision

    # --- Evaluation / checkpoint selection ---
    evaluation_strategy="epoch",        # Evaluate once per epoch
    save_strategy="epoch",              # Save once per epoch
    load_best_model_at_end=True,        # Restore the best checkpoint after training
    metric_for_best_model="f1",         # "Best" is judged by F1 ...
    greater_is_better=True,             # ... where higher is better
)

# Trainer
# Early stopping with a patience of one evaluation
early_stopping = EarlyStoppingCallback(early_stopping_patience=1)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=full_train,            # Full tokenized training set
    eval_dataset=full_val,               # Full tokenized validation set
    compute_metrics=compute_metrics,     # Metrics callback defined above
    callbacks=[early_stopping],
)


In [None]:
# Fine-tune, then write the trainer's model to disk
train_output = trainer.train()
trainer.save_model("./bert_qa_results")

# Evaluate with the trainer's eval dataset and show the metrics
results = trainer.evaluate()
print(results)


Epoch,Training Loss,Validation Loss,Exact Match,F1
1,1.1438,1.152853,75.099338,83.49841
2,0.8962,1.085218,76.982025,85.007131
3,0.6567,1.131582,77.445601,85.189801
4,0.5079,1.217927,77.029328,84.975843


{'eval_loss': 1.1315820217132568, 'eval_exact_match': 77.44560075685904, 'eval_f1': 85.18980079450432, 'eval_runtime': 80.0299, 'eval_samples_per_second': 134.75, 'eval_steps_per_second': 16.844, 'epoch': 4.0}


In [None]:
# Run Predictions and Display Sample Results

# Step 1: Get raw logits for every tokenized validation feature
raw_preds = trainer.predict(tokenized_datasets["validation"])

# Step 2: Convert the logits into one answer string per example id
predictions = postprocess_qa_predictions(
    dataset["validation"],                 # examples: raw questions + contexts
    tokenized_datasets["validation"],      # features: tokenized chunks
    raw_preds.predictions,                 # raw_predictions: (start_logits, end_logits)
    tokenizer,
    n_best_size=40,                        # Top start/end positions to combine
    max_answer_length=30,                  # Longest allowed span, in tokens
    score_threshold=3.0,                   # Spans scoring below this become ""
)

# Step 3: Print the first 5 examples next to their predictions
for i, example in enumerate(dataset["validation"].select(range(5))):
    example_id = example["id"]
    question = example["question"]
    context = example["context"]
    ground_truth = example["answers"]["text"][0]       # First annotated answer
    predicted_answer = predictions[example_id]

    # Single-line preview of the first 200 context characters
    snippet = context[:200].replace("\n", " ")

    print(f"Example {i + 1}")
    print(f"Question:     {question}")
    print(f"Ground Truth: {ground_truth}")
    print(f"Prediction:   {predicted_answer}")
    print(f"Context Snippet: {snippet}...")
    print("-" * 80)
