# GLUE training
This notebook shows how to fine-tune a single model on *all* GLUE tasks simultaneously, with per-task evaluation metrics.

In [None]:
import sys

# Put the parent directory on the import path so the examples run as-is in the package's uv env
# (this is what makes the `grouphug` imports below resolve without installing the package).
sys.path.append("..")  # ensure we can run examples as-is in the package's uv env

In [None]:
import numpy as np
import torch
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset, load_metric
from transformers import AutoTokenizer, TrainingArguments

from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, DatasetFormatter, LMHeadConfig, MultiTaskTrainer
from grouphug.config import logger

# Display whether torch can see a CUDA device before training starts.
torch.cuda.is_available()

## Define which model to fine-tune

In [None]:
# Checkpoint id passed to `from_pretrained` for both the tokenizer and the model (a DeBERTa model).
# For more verbose output, enable: transformers.logging.set_verbosity_info()
base_model = "HannahRoseKirk/Hatemoji"

## Load data

In [None]:
# GLUE task name -> (first text column, second text column); the second entry is None
# for single-sentence tasks. Insertion order defines the order of `tasks`.
task_to_keys = dict(
    cola=("sentence", None),  # single sentence: is it grammatical?
    mnli=("premise", "hypothesis"),  # neutral, entailment or contradiction
    mrpc=("sentence1", "sentence2"),  # are the two sentences semantically equivalent?
    qnli=("question", "sentence"),  # does the context sentence contain the answer to the question?
    qqp=("question1", "question2"),  # are the two questions semantically equivalent?
    rte=("sentence1", "sentence2"),  # entailment, similar to mnli
    sst2=("sentence", None),  # single sentence: sentiment
    stsb=("sentence1", "sentence2"),  # similarity score from 0 to 5 (regression)
    wnli=("sentence1", "sentence2"),  # entailment
)
tasks = list(task_to_keys)


def load_and_rename(task, reduce_size_target=None):
    """Load one GLUE task and normalise it to the column layout shared by all tasks.

    The task's text column(s) are renamed to ``text`` (single-sentence tasks) or
    ``text1``/``text2`` (sentence-pair tasks), and ``label`` is renamed to the task name
    so that every task has its own label column. All splits whose name starts with
    ``validation`` are concatenated into ``validation``, and likewise for ``test``
    (e.g. mnli's matched and mismatched splits).

    Args:
        task: GLUE task name; must be a key of ``task_to_keys``.
        reduce_size_target: optional mapping of split name -> maximum number of rows to
            keep, e.g. ``{"train": 2000}``. Splits not listed keep their full size.

    Returns:
        DatasetDict with ``train``, ``validation`` and ``test`` splits.
    """
    k1, k2 = task_to_keys[task]
    dataset = load_dataset("glue", task).rename_column("label", task)

    if k2 is not None:
        dataset = dataset.rename_column(k1, "text1").rename_column(k2, "text2")
    else:
        dataset = dataset.rename_column(k1, "text")

    dataset = DatasetDict(
        {
            "train": dataset["train"],
            "validation": concatenate_datasets([v for k, v in dataset.items() if k.startswith("validation")]),
            "test": concatenate_datasets([v for k, v in dataset.items() if k.startswith("test")]),
        }
    )

    if reduce_size_target:
        for split_name, target_size in reduce_size_target.items():
            n_keep = min(target_size, len(dataset[split_name]))
            # select() keeps the split's features (e.g. ClassLabel names on the label column);
            # rebuilding from a sliced dict via Dataset.from_dict would re-infer plain types.
            dataset[split_name] = dataset[split_name].select(range(n_keep))
            logger.debug(f"Reducing sizes to {len(dataset[split_name])} for {split_name}")
    return dataset

In [None]:
# Cap the train/validation splits so the demo runs quickly; the test split keeps its full size.
target_size = {"train": 2000, "validation": 100}
glue_data = {name: load_and_rename(name, target_size) for name in tasks}

## Define tokenizer and preprocess data

In [None]:
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Two tokenization steps: one with the formatter's default input column (presumably "text",
# used by the single-sentence tasks) and one for the ("text1", "text2") sentence pairs.
fmt = (
    DatasetFormatter()
    .tokenize(max_length=512)
    .tokenize(("text1", "text2"), max_length=512)
)

# Only the splits used for training and evaluation are formatted; "test" is skipped.
data = fmt.apply(glue_data, tokenizer=tokenizer, splits=["train", "validation"])

In [None]:
# One classification head per GLUE task, configured from that task's label column.
# ignore_index=-1: GLUE marks hidden labels with -1, presumably so such rows are skipped by the loss.
head_configs = [
    ClassificationHeadConfig.from_data(data, task, detached=False, ignore_index=-1)
    for task in tasks
]

# Auxiliary masked-LM head with weight 0.25 (presumably its loss weight). We fine-tune directly
# on masked inputs; this works well in practice, but may hurt tasks where individual words
# matter a lot, such as CoLA.
head_configs.append(LMHeadConfig(weight=0.25))

In [None]:
# Base encoder plus the heads above. Attaching the formatter and tokenizer presumably lets
# model.predict() preprocess raw inputs itself (see the prediction cell at the end).
model = AutoMultiTaskModel.from_pretrained(
    base_model,
    head_configs,
    formatter=fmt,
    tokenizer=tokenizer,
)

In [None]:
output_dir = "../output/demo"

# Effective training batch size per device: 4 examples x 8 gradient-accumulation steps = 32.
training_args = TrainingArguments(
    output_dir=output_dir,
    # Schedule: two epochs, evaluating after each one.
    num_train_epochs=2,
    evaluation_strategy="epoch",  # NOTE(review): newer transformers releases rename this to `eval_strategy`
    # Batching
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    # Keep at most one checkpoint on disk.
    save_total_limit=1,
)

## Define metrics function
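Note the additional `dataset_name` and `heads` arguments. Besides the predictions, the function receives the name of the evaluation dataset being scored and the heads run on it, so it can load the matching GLUE metric for each task.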

In [None]:
def compute_metrics(eval_preds, dataset_name, heads):
    """Score one GLUE validation set with that task's official GLUE metric.

    Args:
        eval_preds: ``(logits, labels)`` pair for the evaluation dataset.
        dataset_name: GLUE task name of the dataset being evaluated; selects the metric.
        heads: heads run on this dataset. Unused here, but part of the signature the
            trainer presumably calls this function with, so it must stay.

    Returns:
        dict mapping metric name to value, as returned by the ``glue`` metric.
    """
    metrics_f = load_metric("glue", dataset_name)
    logits, labels = eval_preds
    if dataset_name == "stsb":
        # STS-B is regression: score the raw outputs. They are flattened because the GLUE
        # metric expects a 1-D vector of floats, while a single-output regression head
        # presumably yields shape (n, 1); already 1-D outputs pass through unchanged.
        predictions = np.asarray(logits).reshape(-1)
    else:
        # Classification: the predicted label is the arg-max over the last (label) axis.
        predictions = np.argmax(logits, axis=-1)
    return metrics_f.compute(predictions=predictions, references=labels)

## Train the model

In [None]:
# `data[:, split]` selects that split across the tasks' datasets. eval_heads maps each
# evaluation dataset name to the heads run on it: here every task is scored only by its own head.
trainer = MultiTaskTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_data=data[:, "train"],
    eval_data=data[:, "validation"],
    eval_heads={name: [name] for name in tasks},
    compute_metrics=compute_metrics,
)

In [None]:
# Run training; `train_res` keeps the summary object returned by `trainer.train()`.
train_res = trainer.train()

## Predict: `model.predict` accepts a dict or an entire dataset, preprocesses it, runs inference, and maps the outputs back to labels

In [None]:
# Predicted label of the CoLA head for a single raw sentence.
model.predict(
    {"text": "The quick brown fox jumped over the lazy dog!"}
)["cola_predicted_label"]