# Fine-tune a Hugging Face Text Classification Model

This notebook mirrors the functionality of `train.py` and walks through the steps required
to fine-tune a Hugging Face text classification model on a JSONL dataset in which every
record has a `text` field and a single-valued `labels` field (one class per example).

In [None]:
from __future__ import annotations

import inspect
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import evaluate
from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

In [None]:
# Configure the root logger at INFO so the progress messages emitted below are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
@dataclass
class ScriptArguments:
    """Configuration settings used throughout the fine-tuning workflow.

    Path fields accept either ``str`` or ``Path``; they are normalised to
    ``Path`` on construction so downstream code can rely on the annotated type.
    """

    # Hugging Face Hub id or local directory of the pretrained checkpoint.
    model_name_or_path: str
    # JSONL file with the training split.
    train_file: Path
    # JSONL file with the validation split, or None to train without evaluation.
    validation_file: Optional[Path]
    # Directory where checkpoints, the final model and the tokenizer are written.
    output_dir: Path
    # Maximum number of tokens kept per example (longer inputs are truncated).
    max_length: int = 512
    learning_rate: float = 5e-5
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    num_train_epochs: int = 3
    weight_decay: float = 0.0
    warmup_ratio: float = 0.0
    logging_steps: int = 50
    # Evaluate every N steps when set; None means evaluate once per epoch.
    eval_steps: Optional[int] = None
    seed: int = 42

    def __post_init__(self) -> None:
        # Strings (e.g. from a CLI parser or a JSON config) would otherwise be
        # stored as-is, despite the ``Path`` annotations.
        self.train_file = Path(self.train_file)
        if self.validation_file is not None:
            self.validation_file = Path(self.validation_file)
        self.output_dir = Path(self.output_dir)


def load_json_dataset(train_path: Path, validation_path: Optional[Path]) -> DatasetDict:
    """Load JSON/JSONL files into a ``DatasetDict``.

    Args:
        train_path: File loaded as the ``train`` split.
        validation_path: File loaded as the ``validation`` split; omitted from
            the result when None.

    Returns:
        The ``DatasetDict`` produced by ``load_dataset("json", ...)``.
    """
    split_paths = {"train": train_path}
    if validation_path is not None:
        split_paths["validation"] = validation_path

    # ``data_files`` is given plain string paths.
    return load_dataset(
        "json",
        data_files={split: str(path) for split, path in split_paths.items()},
    )


def prepare_label_mapping(dataset: DatasetDict) -> tuple[DatasetDict, dict[int, str], dict[str, int]]:
    """Encode the ``labels`` column of every split as consecutive integer ids.

    The label vocabulary comes from the ``train`` split only. Its distinct raw
    values are sorted and numbered ``0..N-1``, and every split is re-encoded
    with that mapping.

    Args:
        dataset: Splits that each contain a ``labels`` column; must include a
            ``train`` split.

    Returns:
        ``(encoded_dataset, id2label, label2id)``. ``id2label`` maps each integer
        id to the string form of the raw label, and ``label2id`` is its inverse.

    Raises:
        KeyError: If a non-train split contains a label that never appears in
            ``train`` (the model would have no output for it).
    """

    # Sort the raw values (not their strings) so numeric labels keep numeric
    # order, e.g. 2 before 10; the string form is what the mappings store.
    unique_labels = sorted(set(dataset["train"]["labels"]))
    id2label = {idx: str(label) for idx, label in enumerate(unique_labels)}
    label2id = {label: idx for idx, label in id2label.items()}

    # Check the other splits up front. Otherwise the failure is an opaque
    # KeyError raised from deep inside ``map`` without naming the split.
    for split_name, split in dataset.items():
        if split_name == "train":
            continue
        unseen = {str(label) for label in split["labels"]}.difference(label2id)
        if unseen:
            raise KeyError(
                f"Split {split_name!r} contains labels not present in the train split: "
                f"{sorted(unseen)}"
            )

    def _encode_labels(batch):
        batch["labels"] = [label2id[str(label)] for label in batch["labels"]]
        return batch

    # Batched mapping avoids one Python call per example.
    dataset = dataset.map(_encode_labels, batched=True)
    return dataset, id2label, label2id


def tokenize_dataset(dataset: DatasetDict, tokenizer: AutoTokenizer, max_length: int) -> DatasetDict:
    """Tokenize the ``text`` column of every split in batches.

    Sequences are truncated to *max_length* tokens but not padded, which leaves
    padding to be applied per batch (e.g. by a padding collator). The raw
    ``text`` column is removed and ``labels`` is carried through unchanged.
    """

    def _encode_batch(batch):
        return {
            **tokenizer(
                batch["text"],
                truncation=True,
                max_length=max_length,
                padding=False,
            ),
            "labels": batch["labels"],
        }

    return dataset.map(
        _encode_batch,
        remove_columns=["text"],
        batched=True,
    )

## Configure the run

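Populate the paths and hyperparameters you want to use for fine-tuning. Set
`validation_file=None` to train without a validation split; evaluation and best-checkpoint
selection are then disabled. With a validation split, evaluation runs every `eval_steps`
steps when `eval_steps` is set, and once per epoch otherwise.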

In [None]:
args = ScriptArguments(
    # Model and data locations (set validation_file=None to skip evaluation).
    model_name_or_path="bert-base-uncased",
    train_file=Path("path/to/train.jsonl"),
    validation_file=Path("path/to/valid.jsonl"),
    output_dir=Path("./model-output"),
    # Tokenization.
    max_length=512,
    # Optimisation.
    learning_rate=5e-5,
    weight_decay=0.0,
    warmup_ratio=0.0,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Logging and evaluation cadence (with a validation file,
    # eval_steps=None evaluates once per epoch).
    logging_steps=50,
    eval_steps=None,
    # Reproducibility.
    seed=42,
)
args  # display the resolved configuration in the notebook

## Load data and tokenizer

In [None]:
logger.info("Loading dataset...")
raw_dataset = load_json_dataset(args.train_file, args.validation_file)

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

# Re-encode labels as ids 0..N-1; the label vocabulary is taken from the train split.
processed_dataset, id2label, label2id = prepare_label_mapping(raw_dataset)
num_labels = len(id2label)

# Tokenize without padding; the collator pads at batch time.
tokenized_dataset = tokenize_dataset(processed_dataset, tokenizer, args.max_length)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

## Set up evaluation metric

In [None]:
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Return classification accuracy for a ``Trainer`` evaluation pass.

    Args:
        eval_pred: The ``(predictions, label_ids)`` pair supplied by ``Trainer``
            (an ``EvalPrediction``); a plain 2-tuple also works.

    Returns:
        The dict produced by the ``evaluate`` accuracy metric (``{"accuracy": ...}``),
        whose key is what ``metric_for_best_model="accuracy"`` refers to.
    """
    # Index rather than unpack: an EvalPrediction may also carry the model
    # inputs as a third element, which would break a two-name unpack.
    predictions, labels = eval_pred[0], eval_pred[1]
    # Models can return extra outputs alongside the logits; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    preds = predictions.argmax(axis=-1)
    return accuracy.compute(predictions=preds, references=labels)

## Initialize model, trainer, and start fine-tuning

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(
    args.model_name_or_path,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

has_validation = args.validation_file is not None
# Evaluate every ``eval_steps`` when requested, otherwise once per epoch; never
# evaluate without a validation split.
evaluate_by_steps = bool(has_validation and args.eval_steps)
evaluation_strategy = "steps" if evaluate_by_steps else ("epoch" if has_validation else "no")

# ``load_best_model_at_end`` requires checkpoints on the same schedule as
# evaluation. Saving per epoch while evaluating per step makes
# ``TrainingArguments`` raise a ValueError.
schedule_kwargs = {"save_strategy": "steps" if evaluate_by_steps else "epoch"}
if evaluate_by_steps:
    # Saving at exactly the eval cadence keeps save_steps a multiple of eval_steps.
    schedule_kwargs["save_steps"] = args.eval_steps

# Newer transformers releases renamed ``evaluation_strategy`` to ``eval_strategy``
# (the old name was removed). Use whichever keyword the installed version accepts.
training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
eval_strategy_key = "eval_strategy" if "eval_strategy" in training_arg_names else "evaluation_strategy"
schedule_kwargs[eval_strategy_key] = evaluation_strategy

training_args = TrainingArguments(
    output_dir=str(args.output_dir),
    learning_rate=args.learning_rate,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    num_train_epochs=args.num_train_epochs,
    weight_decay=args.weight_decay,
    warmup_ratio=args.warmup_ratio,
    logging_steps=args.logging_steps,
    eval_steps=args.eval_steps if has_validation else None,
    load_best_model_at_end=has_validation,
    metric_for_best_model="accuracy" if has_validation else None,
    seed=args.seed,
    **schedule_kwargs,
)

# Likewise, newer ``Trainer`` versions take the tokenizer as ``processing_class``;
# older ones only accept ``tokenizer``.
trainer_arg_names = inspect.signature(Trainer.__init__).parameters
tokenizer_key = "processing_class" if "processing_class" in trainer_arg_names else "tokenizer"

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset.get("validation"),
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{tokenizer_key: tokenizer},
)

trainer.train()
# With a validation split the best checkpoint was reloaded at the end of
# training, so that is the model written here.
trainer.save_model()
tokenizer.save_pretrained(args.output_dir)