# Lab 6: BitFit - Fine-Tuning a BERT Model by Only Training Bias Terms
---
## Notebook 2: The Training Process

**Goal:** In this notebook, you will implement BitFit by manually selecting which parameters of a `bert-base-uncased` model to train.

**You will learn to:**
-   Load a classification dataset and a base model.
-   Write code to iterate through all model parameters and freeze them by default.
-   Selectively unfreeze only the bias (`.bias`) parameters and the final classification head.
-   Use the standard `transformers.Trainer` to fine-tune the partially frozen model.


### Step 1: Load Dataset and Model

This step is identical to the Adapter Layers lab. We will load the GLUE MRPC dataset, preprocess it, and load a `bert-base-uncased` model for sequence classification.
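For readers who skipped the earlier lab, the sketch below shows the layout the tokenizer produces for a sentence pair. It is a simplified illustration, not the real tokenizer: `encode_pair` is a hypothetical helper, it splits on whitespace instead of using WordPiece, and it assumes the pair fits within `max_length`, so truncation is omitted.

```python
def encode_pair(sentence1, sentence2, max_length=12):
    """Lay out a sentence pair in BERT's input format (toy whitespace tokens)."""
    # Segment A: [CLS] + first sentence + [SEP]; segment B: second sentence + [SEP].
    first = ["[CLS]"] + sentence1.split() + ["[SEP]"]
    second = sentence2.split() + ["[SEP]"]
    tokens = first + second
    if len(tokens) > max_length:
        raise ValueError("pair does not fit; this sketch does not implement truncation")
    pad = max_length - len(tokens)
    return {
        "tokens": tokens + ["[PAD]"] * pad,
        # token_type_ids mark which segment each token belongs to (padding gets 0).
        "token_type_ids": [0] * len(first) + [1] * len(second) + [0] * pad,
        # attention_mask is 1 for real tokens and 0 for padding.
        "attention_mask": [1] * len(tokens) + [0] * pad,
    }

example = encode_pair("the cat sat", "a cat was sitting")
for key, value in example.items():
    print(key, value)
```

The real tokenizer returns `input_ids` (integer vocabulary indices) rather than token strings, but `token_type_ids` and `attention_mask` follow the same pattern.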


In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# --- Load Dataset and Tokenizer ---
dataset = load_dataset("glue", "mrpc")
model_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

# --- Preprocessing Function ---
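def preprocess_function(examples):
    # max_length=128 is ample for MRPC; without it, padding="max_length" pads every example to 512 tokens.
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, padding="max_length", max_length=128)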

encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset = encoded_dataset.rename_column("label", "labels")
encoded_dataset.set_format("torch", columns=["input_ids", "attention_mask", "token_type_ids", "labels"])

# --- Load Model ---
num_labels = 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

print("✅ Dataset and model loaded successfully.")


### Step 2: Implement BitFit by Freezing Parameters

This is the core of the BitFit method. We will manually control which parameters are trainable. The logic is as follows:
1.  First, freeze all parameters in the model by setting `param.requires_grad = False`.
2.  Then, iterate through all named parameters again. If a parameter's name contains `.bias`, unfreeze it by setting `param.requires_grad = True`. This matches the bias vectors of the linear layers as well as the `LayerNorm.bias` parameters.
3.  Finally, explicitly unfreeze all parameters of the final classification head (named `classifier` in BERT). This is crucial because the head is newly initialized rather than pretrained, so its weights must be learned for the new task.

After this process, we'll print the number of trainable parameters to see how efficient BitFit is.
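To know roughly what number to expect, the count can be estimated by hand from the architecture. The sketch below applies the three rules above to the standard `bert-base-uncased` dimensions (hidden size 768, 12 layers, intermediate size 3072, vocabulary 30522, 2 labels). The helper names are hypothetical, and the layout is an assumption about how the model is assembled, so treat the result as a cross-check rather than a guarantee.

```python
# Architecture constants for bert-base-uncased (standard config values).
HIDDEN, LAYERS, INTERMEDIATE = 768, 12, 3072
VOCAB, MAX_POS, TYPE_VOCAB, NUM_LABELS = 30522, 512, 2, 2

def linear(n_in, n_out):
    """Return (weight_count, bias_count) for a dense layer."""
    return n_in * n_out, n_out

def layer_norm(dim):
    """Return (weight_count, bias_count) for a LayerNorm."""
    return dim, dim

def encoder_param_groups():
    """List (weights, biases) for every module below the classification head."""
    groups = [((VOCAB + MAX_POS + TYPE_VOCAB) * HIDDEN, 0)]  # embedding tables, no bias
    groups.append(layer_norm(HIDDEN))                        # embedding LayerNorm
    for _ in range(LAYERS):
        groups += [linear(HIDDEN, HIDDEN)] * 4               # query, key, value, attention output
        groups.append(layer_norm(HIDDEN))
        groups.append(linear(HIDDEN, INTERMEDIATE))          # feed-forward up-projection
        groups.append(linear(INTERMEDIATE, HIDDEN))          # feed-forward down-projection
        groups.append(layer_norm(HIDDEN))
    groups.append(linear(HIDDEN, HIDDEN))                    # pooler
    return groups

encoder = encoder_param_groups()
head_weight, head_bias = linear(HIDDEN, NUM_LABELS)

all_params = sum(w + b for w, b in encoder) + head_weight + head_bias
# BitFit as described above: every bias, plus the whole classification head.
trainable_params = sum(b for _, b in encoder) + head_weight + head_bias

print(
    f"trainable params: {trainable_params:,} || all params: {all_params:,} "
    f"|| trainable%: {100 * trainable_params / all_params:.2f}"
)
```

Under these assumptions the estimate is 104,450 trainable parameters out of 109,483,778, or about 0.1% of the model.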


In [None]:
# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

# Unfreeze bias parameters
for name, param in model.named_parameters():
    if ".bias" in name:
        param.requires_grad = True

# Unfreeze the classification head
for param in model.classifier.parameters():
    param.requires_grad = True

# --- Print Trainable Parameters ---
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )

print_trainable_parameters(model)


### Step 3: Set Up and Run Training

Now that the model is correctly configured for BitFit, we can use the standard `transformers.Trainer` to fine-tune it. When the trainer builds its default optimizer, it only includes parameters with `requires_grad=True`, so the frozen weights receive no gradients and are never updated.

The training setup is identical to the Adapter Layers lab.
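The effect of `requires_grad` on optimization can be checked on a toy model before running the full training job. The sketch below is an illustration on a small hypothetical `nn.Sequential` network, not BERT, and the manual optimizer setup only approximates what `Trainer` does internally: it freezes everything except the biases, takes one AdamW step, and confirms that the weights stay fixed while the biases move.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-in for a pretrained network (hypothetical, not BERT).
toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# BitFit-style selection: only parameters whose name ends in "bias" stay trainable.
for name, param in toy.named_parameters():
    param.requires_grad = name.endswith("bias")

# Give the optimizer only the parameters that still require gradients.
trainable = [p for p in toy.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-2)

weight_before = toy[0].weight.detach().clone()
bias_before = toy[0].bias.detach().clone()

# One training step on random data.
inputs = torch.randn(4, 8)
targets = torch.tensor([0, 1, 0, 1])
loss = nn.functional.cross_entropy(toy(inputs), targets)
loss.backward()
optimizer.step()

weight_unchanged = torch.equal(weight_before, toy[0].weight)
bias_changed = not torch.equal(bias_before, toy[0].bias)
print(f"weights unchanged: {weight_unchanged}, biases changed: {bias_changed}")
```

Frozen parameters never get a `.grad` tensor at all, which is also why BitFit reduces optimizer state memory: AdamW keeps moment estimates only for the parameters it was given.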


In [None]:
import numpy as np
import evaluate  # replaces `datasets.load_metric`, which is deprecated and gone from recent `datasets` releases
from transformers import TrainingArguments, Trainer

# --- Metrics Calculation Function ---
# Load the metrics once, outside compute_metrics, so they are not reloaded at every evaluation.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Pick the highest-scoring class for each example.
    predictions = np.argmax(logits, axis=1)
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels)
    return {"accuracy": accuracy["accuracy"], "f1": f1["f1"]}

# --- Training Arguments ---
training_args = TrainingArguments(
    output_dir="./bert-bitfit-mrpc",
    learning_rate=3e-4, # Well above the ~2e-5 typical for full fine-tuning; BitFit updates so few parameters that it tends to benefit from larger steps
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch", # Renamed to `eval_strategy` in newer transformers releases; use that name if this one is rejected
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# --- Create Trainer ---
trainer = Trainer(
    model=model, # Note: We are using the modified `model` directly, not a `PeftModel`
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# --- Start Training ---
print("🚀 Starting training with BitFit...")
trainer.train()
print("✅ Training complete!")
