# Hyperparameter Tuning
*(Note: This notebook runs significantly faster on a GPU. Use the GPUHub, Google Colab, or your own GPU.)*

In this project, you will optimize the hyperparameters of a model in 3 stages.

## Paraphrase Detection
We fine-tune [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) on [MRPC](https://huggingface.co/datasets/glue/viewer/mrpc/train) (the Microsoft Research Paraphrase Corpus), a paraphrase detection dataset. This notebook is adapted from a [PyTorch Lightning example](https://lightning.ai/docs/pytorch/1.9.5/notebooks/lightning_examples/text-transformers.html).

In [None]:
%pip install -q torch transformers lightning datasets wandb evaluate ipywidgets

The next cells are:
* Imports.
* The Weights & Biases login, so that training runs can be logged to W&B.
* `GLUEDataModule`, which loads the task's dataset and creates dataloaders for the training and validation sets.
* The hyperparameter values, with their suggested search ranges.
* `GLUETransformer`, which implements the model's forward pass and the training/validation steps. Its `self.log` calls show which metrics are logged.
* The last cell, which runs training with the given hyperparameters.

In [None]:
from datetime import datetime
from typing import Optional

import wandb
import datasets
import evaluate
import lightning as L
import torch
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

In [None]:
# Authenticate with Weights & Biases; the training cell logs its runs through WandbLogger.
wandb.login()

In [None]:
class GLUEDataModule(L.LightningDataModule):
    """Download, tokenize and batch one GLUE task for Lightning.

    After ``setup`` has run, ``self.dataset`` holds every split of the task,
    tokenized and formatted as torch tensors, and ``self.eval_splits`` lists the
    split names that contain "validation".
    """

    # Text column(s) passed to the tokenizer for each task: a single column, or a
    # (first, second) pair for sentence-pair tasks.
    task_text_field_map = {
        "cola": ["sentence"],
        "sst2": ["sentence"],
        "mrpc": ["sentence1", "sentence2"],
        "qqp": ["question1", "question2"],
        "stsb": ["sentence1", "sentence2"],
        "mnli": ["premise", "hypothesis"],
        "qnli": ["question", "sentence"],
        "rte": ["sentence1", "sentence2"],
        "wnli": ["sentence1", "sentence2"],
        "ax": ["premise", "hypothesis"],
    }

    # Size of the classifier head per task; 1 means a single regression output.
    glue_task_num_labels = {
        "cola": 2,
        "sst2": 2,
        "mrpc": 2,
        "qqp": 2,
        "stsb": 1,
        "mnli": 3,
        "qnli": 2,
        "rte": 2,
        "wnli": 2,
        "ax": 3,
    }

    # Columns that are kept (when present in a split) and returned as tensors.
    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = "mrpc",
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        """Store the settings and load the tokenizer.

        Args:
            model_name_or_path: Hugging Face model id or local path used for the tokenizer.
            task_name: GLUE task key; must appear in ``task_text_field_map``.
            max_seq_length: every example is padded/truncated to this many tokens.
            train_batch_size: batch size of the training dataloader.
            eval_batch_size: batch size of the validation/test dataloaders.
            **kwargs: accepted and ignored.

        Raises:
            KeyError: if ``task_name`` is not one of the known GLUE tasks.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

        # Filled in by setup(); ``None`` marks "not set up yet" so the work runs only once.
        self.dataset = None
        self.eval_splits = []

    def setup(self, stage: str):
        """Load and tokenize every split of the task.

        Idempotent: the notebook calls ``setup("fit")`` itself (the model needs
        ``eval_splits``) and ``Trainer.fit`` calls it again, so later calls reuse
        the already-tokenized data instead of reprocessing it.
        """
        if self.dataset is not None:
            return

        dataset = datasets.load_dataset("glue", self.task_name)
        for split in dataset.keys():
            dataset[split] = dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=["label"],
            )
            self.columns = [c for c in dataset[split].column_names if c in self.loader_columns]
            dataset[split].set_format(type="torch", columns=self.columns)

        self.dataset = dataset
        self.eval_splits = [x for x in dataset.keys() if "validation" in x]

    def prepare_data(self):
        """Download the dataset and tokenizer files ahead of ``setup``; stores nothing on ``self``."""
        datasets.load_dataset("glue", self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True)

    def val_dataloader(self):
        """Return one loader for a single validation split, or a list of loaders when
        there are several (e.g. MNLI matched/mismatched); ``None`` if there are none."""
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["validation"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def test_dataloader(self):
        """With a single eval split, return a loader over the "test" split; with several,
        return loaders over the validation splits (as in the upstream example)."""
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset["test"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def convert_to_features(self, example_batch, indices=None):
        """Tokenize a batch of examples and copy ``label`` into ``labels``.

        Single-text tasks are encoded on their own; sentence-pair tasks are
        encoded as (first, second) pairs. ``labels`` is the argument name the
        model's forward pass expects.

        Args:
            example_batch: a batched mapping from column name to list of values.
            indices: unused; kept for compatibility with ``Dataset.map``.
        """
        first = example_batch[self.text_fields[0]]
        second = example_batch[self.text_fields[1]] if len(self.text_fields) > 1 else None

        # Calling the tokenizer directly replaces the deprecated batch_encode_plus.
        features = self.tokenizer(
            first,
            text_pair=second,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
        )

        features["labels"] = example_batch["label"]
        return features

In [None]:
# --- AdamW optimizer ---
learning_rate: float = 2e-5  # search 1e-5 .. 5e-5 on a log scale
adam_epsilon: float = 1e-8  # search 1e-9 .. 1e-7 on a log scale
weight_decay: float = 0.01  # search 0.0 .. 0.1

# --- Learning-rate schedule ---
warmup_steps: int = 100  # search 0 .. 500

# --- Model ---
dropout_rate: float = 0.1  # search 0.0 .. 0.3

# --- Data ---
max_seq_length: int = 256  # choose one of 64, 128, 256

# --- Trainer ---
gradient_clip_val: float = 1.0  # search 0.5 .. 2.0

# --- Held constant across runs so results stay comparable ---
train_batch_size: int = 32
eval_batch_size: int = 32
num_epochs: int = 3

In [None]:
class GLUETransformer(L.LightningModule):
    """Fine-tune a Hugging Face sequence classifier on a GLUE task.

    Logs per-step training loss, per-epoch training loss/accuracy, validation
    loss/metrics, and the best validation accuracy seen so far. "Accuracy" is
    the GLUE metric's ``accuracy`` value, falling back to ``f1`` and then 0.0
    for tasks whose metric reports neither.

    The keyword defaults are read from the module-level hyperparameter
    variables when this class is defined. Re-run this cell after editing them,
    or pass the values explicitly (as the training cell does).
    """

    # Config attributes that control encoder dropout. BERT-style configs call them
    # ``hidden_dropout_prob`` / ``attention_probs_dropout_prob``; DistilBERT's config
    # uses ``dropout`` / ``attention_dropout``. Only the names the loaded config
    # actually has are set.
    _DROPOUT_CONFIG_ATTRS = (
        "hidden_dropout_prob",
        "attention_probs_dropout_prob",
        "dropout",
        "attention_dropout",
    )

    # Parameter-name fragments excluded from weight decay. "LayerNorm.weight" matches
    # BERT-style names; "layer_norm.weight" matches DistilBERT's sa_layer_norm /
    # output_layer_norm.
    _NO_DECAY = ("bias", "LayerNorm.weight", "layer_norm.weight")

    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        task_name: str,
        learning_rate: float = learning_rate,
        warmup_steps: int = warmup_steps,
        weight_decay: float = weight_decay,
        adam_epsilon: float = adam_epsilon,
        dropout_rate: float = dropout_rate,
        train_batch_size: int = train_batch_size,
        eval_batch_size: int = eval_batch_size,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        """Load the pretrained model (with the requested dropout) and the GLUE metric.

        Args:
            model_name_or_path: Hugging Face model id or local path.
            num_labels: classifier output size; 1 means regression.
            task_name: GLUE task key, also used to load the metric.
            learning_rate, warmup_steps, weight_decay, adam_epsilon: optimizer/schedule settings.
            dropout_rate: applied to every encoder dropout attribute the config defines.
            train_batch_size, eval_batch_size: recorded as hyperparameters only.
            eval_splits: validation split names, used to label per-split MNLI metrics.
            **kwargs: recorded by ``save_hyperparameters`` and otherwise unused.
        """
        super().__init__()

        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        for attr in self._DROPOUT_CONFIG_ATTRS:
            if hasattr(self.config, attr):
                setattr(self.config, attr, dropout_rate)

        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        self.metric = evaluate.load(
            "glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )

        # Per-step results accumulated during an epoch and cleared at epoch end.
        self.validation_step_outputs = []
        self.training_step_outputs = []

        # Best validation accuracy over completed (non-sanity-check) validation epochs.
        self.best_val_accuracy = 0.0

    def forward(self, **inputs):
        """Run the wrapped Hugging Face model on keyword inputs (including ``labels``)."""
        return self.model(**inputs)

    def _predictions(self, logits):
        """Turn logits into predictions: class ids, or raw scores for regression."""
        if self.hparams.num_labels > 1:
            return torch.argmax(logits, dim=-1)
        # Regression: squeeze only the label axis so a final batch of size 1 still
        # gives a 1-D tensor that torch.cat can concatenate at epoch end.
        return logits.squeeze(-1)

    @staticmethod
    def _gather(step_outputs):
        """Combine per-step dicts into (mean loss tensor, preds ndarray, labels ndarray)."""
        loss = torch.stack([x["loss"] for x in step_outputs]).mean()
        preds = torch.cat([x["preds"] for x in step_outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in step_outputs]).detach().cpu().numpy()
        return loss, preds, labels

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch, log it, and keep values for epoch-level metrics."""
        outputs = self(**batch)
        loss = outputs[0]
        logits = outputs[1]

        # Store detached copies: epoch-end aggregation only needs the values, and
        # holding the graph-attached loss would keep autograd references all epoch.
        self.training_step_outputs.append(
            {
                "loss": loss.detach(),
                "preds": self._predictions(logits).detach(),
                "labels": batch["labels"],
            }
        )

        self.log("train_loss_step", loss, prog_bar=False, on_step=True, on_epoch=False)

        return loss

    def on_train_epoch_end(self):
        """Log the epoch's mean training loss and training accuracy, then reset the buffer."""
        if not self.training_step_outputs:
            return

        avg_loss, preds, labels = self._gather(self.training_step_outputs)

        train_metrics = self.metric.compute(predictions=preds, references=labels)
        train_accuracy = train_metrics.get("accuracy", train_metrics.get("f1", 0.0))

        self.log("train_loss", avg_loss, prog_bar=True, on_epoch=True)
        self.log("train_accuracy", train_accuracy, prog_bar=True, on_epoch=True)

        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute the validation loss and predictions for one batch.

        The dataloader index is stored so that per-split metrics (MNLI) can be
        computed at epoch end.
        """
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        self.validation_step_outputs.append(
            {
                "loss": val_loss,
                "preds": self._predictions(logits),
                "labels": batch["labels"],
                "dataloader_idx": dataloader_idx,
            }
        )
        return val_loss

    def _log_mnli_splits(self, step_outputs):
        """Log loss and GLUE metrics separately for each MNLI validation split."""
        by_loader = {}
        for out in step_outputs:
            by_loader.setdefault(out["dataloader_idx"], []).append(out)

        for idx, split_outputs in sorted(by_loader.items()):
            # "matched" or "mismatched"
            split = self.hparams.eval_splits[idx].split("_")[-1]
            loss, preds, labels = self._gather(split_outputs)
            self.log(f"val_loss_{split}", loss, prog_bar=True)
            split_metrics = {
                f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
            }
            self.log_dict(split_metrics, prog_bar=True)

    def on_validation_epoch_end(self):
        """Log validation loss and metrics, and track the best validation accuracy."""
        if not self.validation_step_outputs:
            return

        if self.hparams.task_name == "mnli":
            self._log_mnli_splits(self.validation_step_outputs)
            self.validation_step_outputs.clear()
            return

        loss, preds, labels = self._gather(self.validation_step_outputs)

        val_metrics = self.metric.compute(predictions=preds, references=labels)
        val_accuracy = val_metrics.get("accuracy", val_metrics.get("f1", 0.0))

        # The sanity check before training sees only a few batches, so it must not
        # set the best score.
        if not self.trainer.sanity_checking and val_accuracy > self.best_val_accuracy:
            self.best_val_accuracy = val_accuracy

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", val_accuracy, prog_bar=True)
        self.log("best_val_accuracy", self.best_val_accuracy, prog_bar=True)
        self.log_dict(val_metrics, prog_bar=True)

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay).

        AdamW, with biases and layer-norm weights excluded from weight decay. The
        learning-rate schedule is stepped after every optimizer step.
        """
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if any(fragment in name for fragment in self._NO_DECAY):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.hparams.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]


In [None]:
# Use the logger from the same `lightning` package as `L.Trainer`: mixing the
# standalone `pytorch_lightning` package with `lightning` is not supported.
from lightning.pytorch.loggers import WandbLogger

# Fix random seeds for reproducibility.
L.seed_everything(42)

# Data and model.
dm = GLUEDataModule(
    model_name_or_path="distilbert-base-uncased",
    task_name="mrpc",
    max_seq_length=max_seq_length,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
)
# Run setup here (the Trainer calls it again) because the model needs dm.eval_splits.
dm.setup("fit")

model = GLUETransformer(
    model_name_or_path="distilbert-base-uncased",
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
    task_name=dm.task_name,
    learning_rate=learning_rate,
    warmup_steps=warmup_steps,
    weight_decay=weight_decay,
    adam_epsilon=adam_epsilon,
    dropout_rate=dropout_rate,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
)

# W&B run name containing every tuned hyperparameter.
run_name = (
    f"distilbert_lr{learning_rate}_wd{weight_decay}_"
    f"warmup{warmup_steps}_eps{adam_epsilon}_"
    f"dropout{dropout_rate}_maxseq{max_seq_length}_"
    f"gradclip{gradient_clip_val}"
)

# The model's save_hyperparameters() values are logged automatically by the Trainer.
wandb_logger = WandbLogger(
    project="HyperparameterTuning",
    name=run_name,
    log_model=False,
)
# These settings live outside the model, so log them explicitly to make runs
# filterable/comparable in the W&B config.
wandb_logger.log_hyperparams(
    {
        "max_seq_length": max_seq_length,
        "gradient_clip_val": gradient_clip_val,
        "num_epochs": num_epochs,
    }
)

# Trainer with gradient clipping.
trainer = L.Trainer(
    max_epochs=num_epochs,
    accelerator="auto",
    devices=1,
    logger=wandb_logger,
    gradient_clip_val=gradient_clip_val,
)

# Train. Always close the W&B run afterwards: otherwise re-running this cell with
# new hyperparameters makes WandbLogger reuse the still-active run.
try:
    trainer.fit(model, datamodule=dm)
finally:
    wandb.finish()
