# Text Classification with Trainer, MLflow, & Optuna

This example notebook shows how to:
1. Train a basic text classification model using the Hugging Face `Trainer`
2. Use `optuna` to run a hyperparameter search

Both tasks log their runs to MLflow (through `report_to="mlflow"` in `TrainingArguments`), so the runs can be tracked and compared in the MLflow UI.

In [1]:
!pip install -q datasets transformers[torch] jupyter ipywidgets evaluate mlflow optuna

## Load dataset

In [1]:
from datasets import load_dataset

ds = load_dataset("imdb")

In [2]:
ds["train"][0]

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

## Preprocess dataset

In [3]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
tokenizer.model_max_length = 512  # DistilBERT's maximum input length (512 position embeddings)

In [4]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)
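With `truncation=True` and no explicit `max_length`, each review longer than `tokenizer.model_max_length` is cut from the end so that the special tokens still fit. The sketch below mimics that for a single sequence on plain lists of token IDs. The helper `truncate_ids` is hypothetical, and the IDs 101 and 102 are assumed to be the `[CLS]` and `[SEP]` IDs of the `distilbert-base-uncased` vocabulary.

```python
from typing import List

CLS_ID = 101  # assumption: [CLS] id in the distilbert-base-uncased vocab
SEP_ID = 102  # assumption: [SEP] id in the same vocab


def truncate_ids(content_ids: List[int], max_length: int = 512) -> List[int]:
    """Hypothetical helper: wrap content ids in [CLS] ... [SEP], dropping
    trailing content tokens so the result never exceeds max_length."""
    budget = max_length - 2  # reserve room for the two special tokens
    return [CLS_ID] + content_ids[:budget] + [SEP_ID]


long_review = list(range(1000, 1700))  # 700 fake content token ids
short_review = [2000, 2001, 2002]

print(len(truncate_ids(long_review)))  # 512
print(truncate_ids(short_review))      # [101, 2000, 2001, 2002, 102]
```

Short reviews pass through untouched apart from the special tokens; only reviews over the limit lose their tail.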

In [5]:
tokenized_ds = ds.map(preprocess_function, batched=True)

In [6]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
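`DataCollatorWithPadding` pads each batch only up to the longest sequence in that batch (dynamic padding), instead of padding every review to 512 tokens ahead of time. A minimal stdlib sketch of the idea, assuming a pad ID of 0 (the `[PAD]` ID in the uncased BERT vocabulary) and right-side padding. `pad_batch` is a hypothetical helper, not the collator's implementation, which also converts to tensors and carries labels along.

```python
from typing import Dict, List

PAD_ID = 0  # assumption: [PAD] id for distilbert-base-uncased


def pad_batch(batch: List[List[int]], pad_id: int = PAD_ID) -> Dict[str, List[List[int]]]:
    """Right-pad every sequence to the longest one in this batch and build
    the matching attention mask (1 = real token, 0 = padding)."""
    longest = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = longest - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


out = pad_batch([[101, 7, 102], [101, 7, 8, 9, 102]])
print(out["input_ids"])       # [[101, 7, 102, 0, 0], [101, 7, 8, 9, 102]]
print(out["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```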

## Fine-tune with Trainer

In [7]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
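`compute_metrics` receives the evaluation logits and labels, takes the arg-max over the two class scores, and lets the `evaluate` accuracy metric compare predictions with references. The same arithmetic written out with the standard library, on hypothetical logits and assuming the IMDB convention 0 = negative, 1 = positive:

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score (first one wins on ties, like np.argmax)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy_from_logits(logits: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of rows whose arg-max class equals the reference label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Hypothetical logits for 4 reviews: column 0 = negative, column 1 = positive.
logits = [[2.1, -0.3], [0.2, 1.5], [-1.0, 0.4], [0.9, 0.8]]
labels = [0, 1, 0, 0]
print(accuracy_from_logits(logits, labels))  # 0.75
```

The third review is predicted positive but labeled negative, so 3 of 4 predictions are correct.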

In [9]:
import os
from transformers import TrainingArguments, Trainer, set_seed

set_seed(42)
os.environ["MLFLOW_EXPERIMENT_NAME"] = "single-training-run"
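# os.environ["MLFLOW_TRACKING_URI"] = "<YOUR_MLFLOW_SERVER_URI>"  # uncomment to log to a remote tracking server
DEV_MODE = True  # train and evaluate on 100-example subsets for a quick run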

train_ds = (
    tokenized_ds["train"]
    if not DEV_MODE
    else tokenized_ds["train"].shuffle(seed=42).select(range(100))
)
test_ds = (
    tokenized_ds["test"]
    if not DEV_MODE
    else tokenized_ds["test"].shuffle(seed=42).select(range(100))
)

training_args = TrainingArguments(
    output_dir="./output",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    save_total_limit=1,
    report_to="mlflow",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
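With `DEV_MODE = True`, each split is shuffled with a fixed seed and cut to 100 examples, so repeated runs train and evaluate on the same small subset. A stdlib sketch of that "shuffle, then take the first *n*" pattern; it uses `random.Random`, so the selected indices will differ from what `datasets` produces, and `dev_subset` is a hypothetical name.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def dev_subset(rows: Sequence[T], n: int, seed: int = 42) -> List[T]:
    """Hypothetical helper: deterministically shuffle row indices with `seed`
    and keep the first n rows."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # local RNG, global random state untouched
    return [rows[i] for i in indices[:n]]


rows = [f"review-{i}" for i in range(25_000)]  # stand-in for the 25k-row IMDB train split
a = dev_subset(rows, 100)
b = dev_subset(rows, 100)
print(len(a), a == b)  # 100 True
```

Shuffling before selecting matters here: the IMDB splits are stored grouped by label, so taking the first 100 rows without a shuffle would likely yield a single class.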

In [10]:
train_result = trainer.train()

2024/04/24 21:04:46 INFO mlflow.tracking.fluent: Experiment with name 'single-training-run' does not exist. Creating a new experiment.


| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.6871 | 0.676819 | 0.53 |
| 2 | 0.6611 | 0.658556 | 0.62 |
| 3 | 0.5936 | 0.636642 | 0.76 |
| 4 | 0.5682 | 0.618597 | 0.77 |
| 5 | 0.5084 | 0.61193 | 0.76 |

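Because `load_best_model_at_end=True` and `metric_for_best_model="accuracy"`, the Trainer should restore the checkpoint with the highest evaluation accuracy rather than the one from the last epoch. Applying that rule by hand to the accuracies in the training log above, as a simplified sketch: the Trainer only replaces its best checkpoint when a later epoch is strictly better, and `best_epoch` is a hypothetical helper.

```python
from typing import List, Tuple


def best_epoch(accuracies: List[float]) -> Tuple[int, float]:
    """Return (1-based epoch, accuracy) of the first strictly-best epoch."""
    best_idx = 0
    for i, acc in enumerate(accuracies):
        if acc > accuracies[best_idx]:  # strict '>' keeps the earliest of tied epochs
            best_idx = i
    return best_idx + 1, accuracies[best_idx]


# Validation accuracy per epoch, copied from the single-run training log.
single_run = [0.53, 0.62, 0.76, 0.77, 0.76]
print(best_epoch(single_run))  # (4, 0.77)
```

So the model held by `trainer` after this run corresponds to epoch 4, not epoch 5.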

## HPO Sweep using Optuna

In [13]:
os.environ["MLFLOW_EXPERIMENT_NAME"] = "hpo-sweep"


def model_init(trial):
    return AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )


# Define the objective function for the Optuna study
# Notice how we set the learning_rate hyperparameter using the trial object below
def objective(trial):
    training_args = TrainingArguments(
        output_dir=f"./output/trial-{trial.number}",
        run_name=f"trial-{trial.number}",
        learning_rate=trial.suggest_float(
            "learning_rate", 1e-5, 1e-3, log=True
        ),  # Define hyperparameter space using the trial object
        per_device_train_batch_size=8,
        num_train_epochs=5,
        weight_decay=0.01,
        evaluation_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        save_total_limit=1,
        report_to="mlflow",
    )

    # Initialize the Trainer with the current set of hyperparameters
    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Train the model
    trainer.train()

    return trainer.state.best_metric
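`trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)` samples the learning rate on a log scale, so each decade (1e-5 to 1e-4 and 1e-4 to 1e-3) is equally likely. With Optuna's default TPE sampler, the first trials (10 by default) are drawn independently at random, so with `n_trials=3` every trial below is effectively such a random draw. A stdlib sketch of log-uniform sampling; `sample_log_uniform` is hypothetical and is not Optuna's code.

```python
import math
import random


def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw x so that log(x) is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_log_uniform(1e-5, 1e-3, rng) for _ in range(10_000)]

# About half the draws should fall below the geometric midpoint 1e-4.
# A plain uniform draw on [1e-5, 1e-3] would put only about 9% there.
below = sum(s < 1e-4 for s in samples) / len(samples)
print(round(below, 2))
```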

In [14]:
import optuna

# Set up the Optuna study
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=3)

[I 2024-04-24 21:05:46,383] A new study created in memory with name: no-name-b4f955cb-e6e4-4ea9-be72-0fce6a85058d


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2024/04/24 21:05:47 INFO mlflow.tracking.fluent: Experiment with name 'hpo-sweep' does not exist. Creating a new experiment.


| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.8284 | 0.699672 | 0.47 |
| 2 | 0.7256 | 0.699299 | 0.53 |
| 3 | 0.7198 | 0.692394 | 0.53 |
| 4 | 0.7299 | 0.70013 | 0.53 |
| 5 | 0.6885 | 0.686363 | 0.53 |


[I 2024-04-24 21:06:20,488] Trial 0 finished with value: 0.53 and parameters: {'learning_rate': 0.0002792661161912271}. Best is trial 0 with value: 0.53.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.706 | 0.688791 | 0.53 |
| 2 | 0.6508 | 0.690406 | 0.53 |
| 3 | 0.5733 | 0.674137 | 0.6 |
| 4 | 0.4995 | 0.654065 | 0.68 |
| 5 | 0.4291 | 0.649037 | 0.68 |


[I 2024-04-24 21:06:53,702] Trial 1 finished with value: 0.68 and parameters: {'learning_rate': 2.4050600466309405e-05}. Best is trial 1 with value: 0.68.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.8509 | 0.700729 | 0.47 |
| 2 | 0.7182 | 0.71163 | 0.53 |
| 3 | 0.7404 | 0.696181 | 0.47 |
| 4 | 0.7204 | 0.709847 | 0.53 |
| 5 | 0.6979 | 0.693116 | 0.53 |


[I 2024-04-24 21:07:26,905] Trial 2 finished with value: 0.53 and parameters: {'learning_rate': 0.0006531910660327685}. Best is trial 1 with value: 0.68.
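Trials 0 and 2, the two largest learning rates, only ever report 0.47 or 0.53 accuracy. That pattern is consistent with a model that predicts a single class for every review, if the 100-example evaluation subset happens to contain 53 reviews of one class and 47 of the other. The notebook does not print that split, so treat this as a hypothesis to check. A majority-class baseline makes the comparison explicit; the label list below is a hypothetical 53/47 split, and with the real data you would pass `test_ds["label"]` instead.

```python
from collections import Counter
from typing import List


def majority_baseline(labels: List[int]) -> float:
    """Accuracy of always predicting the most frequent label."""
    counts = Counter(labels)
    return max(counts.values()) / len(labels)


def constant_accuracy(labels: List[int], predicted: int) -> float:
    """Accuracy of always predicting `predicted`."""
    return sum(int(y == predicted) for y in labels) / len(labels)


# Hypothetical 53/47 split standing in for test_ds["label"].
labels = [1] * 53 + [0] * 47
print(majority_baseline(labels))     # 0.53
print(constant_accuracy(labels, 0))  # 0.47
```

If the real split matches, an accuracy of 0.53 means those trials learned nothing beyond the class prior.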


In [15]:
study.trials_dataframe()

| | number | value | datetime_start | datetime_complete | duration | params_learning_rate | state |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.53 | 2024-04-24 21:05:46.384553 | 2024-04-24 21:06:20.488047 | 0 days 00:00:34.103494 | 0.000279 | COMPLETE |
| 1 | 1 | 0.68 | 2024-04-24 21:06:20.488696 | 2024-04-24 21:06:53.702450 | 0 days 00:00:33.213754 | 2.4e-05 | COMPLETE |
| 2 | 2 | 0.53 | 2024-04-24 21:06:53.703321 | 2024-04-24 21:07:26.905253 | 0 days 00:00:33.201932 | 0.000653 | COMPLETE |


In [16]:
study.best_trial

FrozenTrial(number=1, state=1, values=[0.68], datetime_start=datetime.datetime(2024, 4, 24, 21, 6, 20, 488696), datetime_complete=datetime.datetime(2024, 4, 24, 21, 6, 53, 702450), params={'learning_rate': 2.4050600466309405e-05}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'learning_rate': FloatDistribution(high=0.001, log=True, low=1e-05, step=None)}, trial_id=1, value=None)

After running the hyperparameter search, launch the MLflow UI locally by running `mlflow ui` from the directory that contains the `mlruns` folder. If you have set the `MLFLOW_TRACKING_URI` environment variable, the runs are logged to that tracking server instead, and you can view them there.

The UI should look like the following:

![](images/mlflow.png)