# In [None]:
from datasets import load_dataset
import evaluate
import optuna
import torch

from transformers import AutoConfig
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import set_seed
from transformers import Trainer
from transformers import TrainingArguments


# Seed the random number generators once, before anything else runs, so that
# repeated executions of this notebook are reproducible.
set_seed(42)

# Single checkpoint shared by the tokenizer and the classification model.
model_name = "lvwerra/distilbert-imdb"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# IMDb movie-review sentiment corpus, fetched via the `datasets` library.
dataset = load_dataset("imdb")

# Accuracy metric from `evaluate` (the replacement for the deprecated
# `datasets.load_metric`).
metric = evaluate.load("accuracy")


def tokenize(batch):
    """Tokenize the raw review strings found under ``batch["text"]``.

    Meant to be used with ``Dataset.map(..., batched=True)``. Sequences are
    truncated (``truncation=True``) and padded to the longest sequence of the
    batch being mapped (``padding=True``).
    """
    texts = batch["text"]
    encoding = tokenizer(texts, truncation=True, padding=True)
    return encoding


# The IMDb "train" and "test" splits are stored ordered by label (all negative
# reviews first, then all positive ones), so taking a plain prefix with
# ``select(range(N))`` would yield a single-class subset and a meaningless
# accuracy. Shuffle each split with a fixed seed before taking the small
# subsets used for fast experiments.
train_dataset = dataset["train"].shuffle(seed=42).select(range(1500))
eval_dataset = dataset["test"].shuffle(seed=42).select(range(500))


def _prepare_split(split):
    """Tokenize *split* and expose the Trainer's input columns as torch tensors.

    The ``label`` column is renamed to ``labels`` because that is the keyword
    argument the sequence-classification model uses to compute its loss.
    """
    split = split.map(tokenize, batched=True)
    split = split.rename_column("label", "labels")
    split.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    return split


# Tokenize only the subsets that are actually used, rather than every split of
# the DatasetDict (for IMDb that also includes the large unlabelled
# "unsupervised" split, which this script never reads).
train_dataset = _prepare_split(train_dataset)
eval_dataset = _prepare_split(eval_dataset)

# Model config
config = AutoConfig.from_pretrained(model_name, num_labels=2)


def model_init(trial):
    """Instantiate a fresh sequence classifier from the ``model_name`` checkpoint.

    Passed to the Trainer as ``model_init`` so that every hyperparameter-search
    trial starts from the same pretrained weights and ``config``. ``trial`` is
    accepted for the Trainer's calling convention but is not used here.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(
        model_name, config=config
    )
    return classifier


def compute_metrics(eval_pred):
    """Compute accuracy from a ``(logits, labels)`` evaluation pair.

    The predicted class is the index of the largest logit along the last axis.
    The result is whatever the module-level ``metric`` returns.
    """
    model_logits, gold_labels = eval_pred
    predicted_classes = model_logits.argmax(axis=-1)
    return metric.compute(predictions=predicted_classes, references=gold_labels)


def optuna_hp_space(trial):
    """Sample one hyperparameter configuration for the Optuna backend.

    Returns a dict keyed by ``TrainingArguments`` field names:
    a log-uniform learning rate in [1e-6, 1e-4], a per-device batch size
    of 8 or 16, and 2 or 3 training epochs.
    """
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    batch_size = trial.suggest_categorical("per_device_train_batch_size", [8, 16])
    epochs = trial.suggest_int("num_train_epochs", 2, 3)
    return {
        "learning_rate": lr,
        "per_device_train_batch_size": batch_size,
        "num_train_epochs": epochs,
    }


# Training arguments shared by every trial; the fields returned by
# optuna_hp_space override the matching values for each individual trial.
training_args = TrainingArguments(
    output_dir="./hp_search",
    save_strategy="no",  # trials are throwaway; don't write checkpoints
    report_to="none",  # keep the search free of external experiment loggers
)

trainer = Trainer(
    model_init=model_init,  # rebuild the model from scratch for every trial
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    # Rows produced by a batched map with padding=True are only padded within
    # each map batch, so they may differ in length; pad every training/eval
    # batch to a common length so the tensors can be stacked.
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)


def compute_objective(metrics):
    """Return the value Optuna maximizes: accuracy on the evaluation set.

    ``Trainer.evaluate`` prefixes metric names with ``eval_``, so the
    ``accuracy`` key from compute_metrics arrives as ``eval_accuracy``.
    """
    return metrics["eval_accuracy"]


best_run = trainer.hyperparameter_search(
    direction="maximize",
    backend="optuna",
    hp_space=optuna_hp_space,
    n_trials=5,
    compute_objective=compute_objective,
)

print(best_run)