# This demo tests the effect of different language modelling heads

In [None]:
import sys

# Append the parent directory to the module search path before the imports in
# the following cells run. Presumably the parent directory holds the package
# source when this notebook is run from its own folder — TODO confirm.
sys.path.append("..")  # ensure we can run examples as-is in the package's uv env

In [None]:
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, TrainingArguments
import torch
from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, DatasetFormatter, LMHeadConfig, MultiTaskTrainer

from utils import compute_classification_metrics

## A basic modelling task similar to the README example

In [None]:
# Base checkpoint shared by the tokenizer here and by every model built later.
base_model = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Load the "emotion" configuration of tweet_eval and rename its "label" column
# to "emotion", which is the name the classification head is built from below.
tweet_emotion = load_dataset("tweet_eval", "emotion")
tweet_emotion = tweet_emotion.rename_column("label", "emotion")

# Build a formatter with a tokenize step and apply it to the dataset.
formatter = DatasetFormatter()
formatter = formatter.tokenize()
data = formatter.apply(tweet_emotion, tokenizer=tokenizer)

# A single classification head on the "emotion" column, configured from the data.
emotion_head = ClassificationHeadConfig.from_data(data, "emotion", classifier_hidden_size=32)
head_configs = [emotion_head]

## Adding different LM heads to a classification task and training

In [None]:
# One TrainingArguments instance is shared by every run: evaluate once per
# epoch, train for 10 epochs, and never write checkpoints.
training_args = TrainingArguments(
    output_dir="../output",
    evaluation_strategy="epoch",
    num_train_epochs=10,
    save_strategy="no",
)

# Extra language-modelling heads to add on top of the classification head,
# keyed by a short run name. "none" is the classification-only baseline.
test_lm_heads = {
    "none": [],
    "mlm": [LMHeadConfig(weight=0.2)],
    "mtd": [LMHeadConfig(masked_token_detection=True, weight=0.2)],
    "mlm+mtd": [
        LMHeadConfig(masked_language_modelling=True, masked_token_detection=True, weight=0.2),
    ],
}

# Maps run name -> DataFrame of that run's trainer log history.
results = {}

for run_name, extra_heads in test_lm_heads.items():
    # Start every run from a fresh copy of the base model with the
    # classification head plus this run's extra heads.
    all_heads = head_configs + extra_heads
    model = AutoMultiTaskModel.from_pretrained(
        base_model,
        all_heads,
        formatter=formatter,
        tokenizer=tokenizer,
    )

    trainer = MultiTaskTrainer(
        model=model,
        tokenizer=tokenizer,
        train_data=data[:, "train"],
        eval_data=data[:, "test"],
        eval_heads=["emotion"],
        compute_metrics=compute_classification_metrics,
        args=training_args,
    )
    trainer.train()

    # Keep only the logged history; the model and trainer are discarded below.
    log_frame = pd.DataFrame(trainer.state.log_history)
    results[run_name] = log_frame

    # Drop the references to the model and trainer, then ask PyTorch to release
    # its cached CUDA memory before the next run allocates a new model.
    model = trainer = None
    torch.cuda.empty_cache()

## Inspecting results

In [None]:
import matplotlib.pyplot as plt

# Plot one panel per logged quantity, with one line per run in `results`.
metrics = ["loss", "eval_loss", "eval_emotion_f1", "eval_emotion_matthews_correlation"]

# Create the figure and all panels once, instead of re-requesting the same
# subplot for every run. squeeze=False keeps `axes` two-dimensional, so the
# indexing below does not depend on how many metrics are listed.
fig, axes = plt.subplots(1, len(metrics), figsize=(20, 5), squeeze=False)

for ax, metric in zip(axes[0], metrics):
    for test_name, df in results.items():
        # Each log row only carries some of the columns, so keep the rows that
        # have a value for this metric. `subset` is passed as a list because a
        # scalar label is only accepted by pandas >= 1.5.
        metric_rows = df.dropna(subset=[metric])
        # Label each line with its run name so the legend does not rely on the
        # iteration order of `results` matching the order the lines were drawn.
        metric_rows.plot(x="step", y=metric, ax=ax, label=test_name)
    ax.set_title(metric)