# Exercise: Random search with Optuna
In this exercise we will build a very simple fine-tuning example for a tiny language model. The focus here is on hyperparameter optimization.

## Base model
We will use [roneneldan/TinyStories-1M](https://huggingface.co/roneneldan/TinyStories-1M), a 1-million-parameter model trained to write short children's stories in simple English.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, Trainer, TrainingArguments

def get_model_and_tokenizer() -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load the TinyStories-1M checkpoint together with its tokenizer.

    If the tokenizer has no padding token, its end-of-sequence token is
    registered as the padding token so that batches can be padded.

    Returns:
        The ``(model, tokenizer)`` pair for ``roneneldan/TinyStories-1M``.
    """
    checkpoint = "roneneldan/TinyStories-1M"

    tok = AutoTokenizer.from_pretrained(checkpoint)
    if tok.pad_token is None:
        # No dedicated pad token configured: reuse EOS for padding.
        tok.pad_token = tok.eos_token

    return AutoModelForCausalLM.from_pretrained(checkpoint), tok

## Dataset for supervised finetuning
A dataset called NamedTinyStories has been prepared in the project directory. It was created by asking a bigger TinyStories model to generate stories about a given name, where each name is drawn from the most popular names in the United States over the last century.

In [None]:
#from typing import Sequence
from collections import defaultdict

from datasets import Dataset, load_from_disk
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizer

def get_dataset() -> Dataset:
    """Load the prepared NamedTinyStories dataset from its on-disk location."""
    dataset_dir = "/mimer/NOBACKUP/groups/llm-workshop/datasets/NamedTinyStories/dataset"
    return load_from_disk(dataset_dir)

def collate_fn(examples: list[dict], tokenizer: PreTrainedTokenizer):
    """Build a left-padded training batch from raw story/name examples.

    Every sequence is laid out as ``story + "\\n" + name + EOS``. Labels are
    ``-100`` for the story and newline positions, so only the name tokens and
    the trailing EOS carry a real label.

    Args:
        examples: Rows providing a ``"story"`` and a ``"name"`` entry.
        tokenizer: Tokenizer used to encode the text and to pad the batch.

    Returns:
        The result of ``tokenizer.pad`` holding ``input_ids``,
        ``attention_mask`` and ``labels`` as PyTorch tensors.
    """
    def encode(text):
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    separator = encode("\n")

    features = defaultdict(list)
    for row in examples:
        prompt = encode(row["story"]) + separator
        target = encode(row["name"]) + [tokenizer.eos_token_id]
        sequence = prompt + target

        features["input_ids"].append(sequence)
        features["attention_mask"].append([1] * len(sequence))
        # Mask the prompt so that only the name (+ EOS) is labelled.
        features["labels"].append([-100] * len(prompt) + target)

    # The tokenizer does not pad labels, so left-pad them with -100 by hand.
    longest = max(map(len, features["labels"]))
    features["labels"] = [
        [-100] * (longest - len(seq)) + seq for seq in features["labels"]
    ]
    return tokenizer.pad(features, return_tensors="pt", padding=True, padding_side="left")


## Add metrics of interest
This is optional, as we will optimize the loss in this case, but cross-entropy loss is not very human-readable, so let's add a metric that is easier to follow during training.

In [None]:
from transformers import EvalPrediction, TrainerCallback, pipeline

class ComputeMetrics:
    """Exact-match accuracy over the labelled tokens of each sequence.

    A sequence counts as correct only if the arg-max prediction equals the
    label at every position whose label is not ``-100``. Counts are
    accumulated across calls so the metric can be built up batch by batch.

    Args:
        batchwise: ``True`` when the instance is called once per evaluation
            batch and should report only on the call made with
            ``compute_result=True``; ``False`` when every call should report
            (and reset) immediately.
    """

    def __init__(self, batchwise: bool):
        self.batchwise = batchwise
        self.reset()

    def reset(self):
        """Clear the running counters."""
        self.correct = 0
        self.total = 0

    def __call__(self, eval_pred: "EvalPrediction", compute_result: bool = False):
        """Accumulate one batch and optionally return the accuracy.

        Args:
            eval_pred: Pair of ``(logits, labels)``; logits are indexed as
                ``(batch, sequence, vocabulary)`` and labels as
                ``(batch, sequence)``. Tensors and NumPy arrays are accepted.
            compute_result: Set on the last batch of a batch-wise evaluation
                to request the aggregated result.

        Returns:
            ``{"accuracy": float}`` when a result is due, otherwise ``{}``.
            The counters are reset whenever a result is returned.
        """
        import torch  # local import so the class does not rely on another cell

        logits, labels = eval_pred
        # The Trainer passes tensors with batch_eval_metrics=True but NumPy
        # arrays otherwise; as_tensor is a no-op for tensors and makes the
        # array case work (ndarray.argmax does not accept `dim=`).
        logits = torch.as_tensor(logits)
        labels = torch.as_tensor(labels)

        # Get most likely prediction and shift to match prediction with label:
        # the prediction at position t is compared with the label at t + 1.
        preds = logits.argmax(dim=-1)[:, :-1]
        labels = labels[:, 1:]

        # A sequence is a hit when every non-ignored position is predicted right.
        hits = ((preds == labels) | (labels == -100)).all(dim=1)
        self.correct += int(hits.sum())
        self.total += preds.size(0)

        if compute_result or not self.batchwise:
            # NaN (rather than ZeroDivisionError) when nothing was counted.
            accuracy = self.correct / self.total if self.total else float("nan")
            self.reset()
            return {"accuracy": float(accuracy)}

        return {}

## Enable Optuna pruning
Optuna can optionally use pruning methods. For this to work with the Hugging Face Trainer, we will implement a callback that can abort a trial based on the evaluation score.

In [None]:
from optuna import TrialPruned
from optuna.trial import Trial
from transformers import TrainerCallback


class OptunaPruningCallback(TrainerCallback):
    """
    Trainer callback that forwards one evaluation metric to an Optuna trial
    and aborts training once Optuna asks for the trial to be pruned.
    """

    def __init__(self, trial: Trial, metric_to_optimize: str):
        super().__init__()
        self.metric_to_optimize = metric_to_optimize
        self.trial = trial

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """
        Report the tracked metric at the current global step and raise
        ``TrialPruned`` if the trial should stop. Without metrics this is a no-op.
        """
        if metrics is not None:
            score = metrics[self.metric_to_optimize]
            step = int(state.global_step)
            self.trial.report(score, step=step)

            if self.trial.should_prune():
                # Optuna's own exception, so the study records the trial as pruned.
                raise TrialPruned(f"Trial was pruned at step {step}.")

        return control

## Specifying the objective of the hyperparameter optimization
Now we need an objective to track when we run our hyperparameter optimization.

In [None]:
from functools import partial

from optuna.trial import Trial
from torch import Generator
from torch.utils.data import random_split
from transformers import PreTrainedModel, Trainer, TrainingArguments

def finetune_stories(trial: Trial) -> float:
    """Optuna objective: fine-tune the model once and return its best metric.

    Samples the learning rate and weight decay from ``trial``, trains on a
    seeded train/eval split of the dataset with one evaluation per epoch, and
    reports every evaluation to Optuna through ``OptunaPruningCallback``.

    Args:
        trial: Optuna trial used to sample hyperparameters, to store the fixed
            settings and best metrics as user attributes, and to report
            intermediate values.

    Returns:
        The lowest ``eval_loss`` found in the trainer's log history.

    Raises:
        RuntimeError: If no log entry contains the optimized metric.
    """
    model, tokenizer = get_model_and_tokenizer()

    # Hyperparameters
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)

    # Not tuned, but let's log them to Optuna anyway as they impact performance.
    # They are kept in a local dict so later reads do not go through the trial.
    fixed_settings = {
        "batch_size": 16,
        "train_dataset_size": 0.8,
        "eval_dataset_size": 0.2,
        "train_eval_split_seed": 10371,
        "num_epochs": 30,
        "metric_to_optimize": "eval_loss",
    }
    for key, value in fixed_settings.items():
        trial.set_user_attr(key, value)
    metric = fixed_settings["metric_to_optimize"]

    # Initialize datasets (seeded so every trial sees the same split)
    ds_rng = Generator(device="cpu")
    ds_rng.manual_seed(fixed_settings["train_eval_split_seed"])
    train_dataset, eval_dataset = random_split(
        get_dataset(),
        lengths=[
            fixed_settings["train_dataset_size"],
            fixed_settings["eval_dataset_size"],
        ],
        generator=ds_rng,
    )

    # Define training parameters
    training_args = TrainingArguments(
        output_dir=f"./{trial.study.study_name}_trial{trial.number:03d}",
        learning_rate=lr,
        weight_decay=weight_decay,
        per_device_train_batch_size=fixed_settings["batch_size"],
        per_device_eval_batch_size=fixed_settings["batch_size"],
        dataloader_drop_last=True,
        eval_strategy="epoch",
        batch_eval_metrics=True,
        logging_strategy="epoch",
        num_train_epochs=fixed_settings["num_epochs"],
        remove_unused_columns=False,  # is done by custom collator instead
        fp16=True,
        dataloader_pin_memory=True,
        save_strategy="no",
        push_to_hub=False,
        disable_tqdm=True,
        report_to=[],
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=partial(collate_fn, tokenizer=tokenizer),
        compute_metrics=ComputeMetrics(batchwise=training_args.batch_eval_metrics),
        callbacks=[OptunaPruningCallback(trial, metric_to_optimize=metric)],
    )

    # Run training
    trainer.train()

    # Get metric to optimize. Only log entries that contain the metric are
    # candidates; training-only entries would otherwise be selectable and
    # fail later with a bare KeyError.
    candidates = [
        entry for entry in trainer.state.log_history
        if isinstance(entry, dict) and metric in entry
    ]
    if not candidates:
        raise RuntimeError(
            f"No '{metric}' entry was logged during training; cannot score the trial."
        )
    best_metrics = min(candidates, key=lambda entry: entry[metric])
    trial.set_user_attr("best_metrics", best_metrics)
    return best_metrics[metric]

## Run the hyperparameter optimization
Only one thing left to do now.

In [None]:
from optuna import create_study, samplers, pruners
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend

# Keep all trials in a local journal file; the inspection cells further down
# open the same file to load the study again.
storage = JournalStorage(
    JournalFileBackend("./optuna_journal_storage.log"),
    #JournalFileBackend("/mimer/NOBACKUP/groups/...")
)

# Create the study, or re-use the one with this name if the storage already has it.
study = create_study(
    study_name="NamedStories_v0.1",
    storage=storage,
    load_if_exists=True,
    sampler=samplers.RandomSampler(),  # https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
    pruner=pruners.NopPruner(),  # https://optuna.readthedocs.io/en/stable/reference/pruners.html
    direction="minimize",
)
# Run a single trial of the objective; raise n_trials to try more configurations.
study.optimize(finetune_stories, n_trials=1)

In [None]:
# This cell is just to stop execution at this point if not running in Jupyter
import sys

from IPython import get_ipython
from ipykernel.zmqshell import ZMQInteractiveShell

# get_ipython() returns None when no IPython shell is running (for example
# when the exported notebook is executed as a plain script), so leave through
# sys.exit() in that case instead of failing on None.ask_exit().
shell = get_ipython()
if shell is None:
    sys.exit(0)
elif not isinstance(shell, ZMQInteractiveShell):
    # IPython without a Jupyter kernel: ask the shell to exit.
    shell.ask_exit()

## Inspect results

In [None]:
from optuna import load_study
from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend

# Open the same journal file that the optimization cell writes to.
storage = JournalStorage(
    JournalFileBackend("./optuna_journal_storage.log"),
    #JournalFileBackend("/mimer/NOBACKUP/groups/...")
)

# Load the existing study by name; nothing is created or optimized here.
study = load_study(
    study_name="NamedStories_v0.1",
    storage=storage,
)

In [None]:
import plotly.offline as pyo

# Set up plotly's offline notebook mode before the Optuna figures are drawn.
pyo.init_notebook_mode()

In [None]:
from optuna.visualization import plot_contour

# Contour plot of the objective value over the study's hyperparameters.
plot_contour(study)

In [None]:
from optuna.visualization import plot_param_importances

# Plot the hyperparameter importances computed by Optuna for this study.
plot_param_importances(study)

In [None]:
# NOTE(review): looks like a leftover scratch cell; it only displays `inf`
# and does not touch the study. Confirm it can be removed.
float("inf")