# %% [markdown]
# # Hyperparameter Optimization with Optuna
# 
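# This notebook uses Optuna to tune supervised fine-tuning (SFT) hyperparameters automatically:
# - Learning rate (log-uniform search)
# - Batch size and gradient accumulation steps
# - LoRA rank, alpha, and dropout
# - Warmup ratio and weight decay
# - Learning-rate scheduler type (linear, cosine, cosine with restarts); early stopping / trial pruning is not implemented yet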


In [None]:
# %%
# Import libraries and configuration
import gc
import inspect
import os

import matplotlib.pyplot as plt
import optuna
import torch
from datasets import load_from_disk
from optuna.visualization import plot_optimization_history, plot_param_importances
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Configuration: everything a reader might want to tweak for a tuning run.
BASE_MODEL = "deepseek-ai/DeepSeek-V3-Base"        # Hugging Face model id to fine-tune
DATASET_PATH = "../data/processed/sft_dataset"    # directory read with `load_from_disk`
OUTPUT_DIR = "../models/optuna_trials"            # trial checkpoints, study DB and reports
N_TRIALS = 20                                     # trials per `study.optimize` call
TIMEOUT = 6 * 60 * 60                             # wall-clock budget in seconds (6 hours)

# Create the output directory up front (no-op if it already exists)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# %% [markdown]
# ## Load Dataset


In [None]:
# %%
# Load the processed SFT splits from disk
dataset = load_from_disk(DATASET_PATH)
train_dataset, val_dataset = dataset["train"], dataset["validation"]

# Tune on the first rows of each split only, so every trial stays short
n_train = min(5000, len(train_dataset))
n_val = min(500, len(val_dataset))
train_subset = train_dataset.select(range(n_train))
val_subset = val_dataset.select(range(n_val))

print(f"Training subset: {len(train_subset)} samples")
print(f"Validation subset: {len(val_subset)} samples")

# %% [markdown]
# ## Define Objective Function

In [None]:
# %%
# Optuna objective function
def _suggest_hyperparameters(trial):
    """Sample one point of the search space from ``trial``.

    Suggestions are made in a fixed order and returned as a dict keyed by the
    Optuna parameter names (the same names later found in ``trial.params``).
    """
    return {
        # Optimizer / training loop
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True),
        "batch_size": trial.suggest_categorical("batch_size", [2, 4, 8]),
        "gradient_accumulation_steps": trial.suggest_categorical(
            "gradient_accumulation_steps", [2, 4, 8]
        ),
        "warmup_ratio": trial.suggest_float("warmup_ratio", 0.0, 0.1),
        "weight_decay": trial.suggest_float("weight_decay", 0.0, 0.1),
        # LoRA adapter
        "lora_r": trial.suggest_categorical("lora_r", [8, 16, 32, 64]),
        "lora_alpha": trial.suggest_categorical("lora_alpha", [16, 32, 64]),
        "lora_dropout": trial.suggest_float("lora_dropout", 0.0, 0.2),
        # Scheduler
        "lr_scheduler_type": trial.suggest_categorical(
            "lr_scheduler_type", ["linear", "cosine", "cosine_with_restarts"]
        ),
    }


def _load_lora_model(hp):
    """Load ``BASE_MODEL`` quantized to 4-bit NF4 and attach a fresh LoRA adapter.

    Args:
        hp: Hyperparameter dict from ``_suggest_hyperparameters``; reads
            ``lora_r``, ``lora_alpha`` and ``lora_dropout``.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,  # matches fp16=True in the training args
        bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)

    lora_config = LoraConfig(
        r=hp["lora_r"],
        lora_alpha=hp["lora_alpha"],
        # NOTE(review): confirm these projection names exist in BASE_MODEL's
        # attention layers; modules whose names do not match get no adapter.
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        lora_dropout=hp["lora_dropout"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, lora_config)


def _build_training_args(trial_number, hp):
    """Build ``TrainingArguments`` for one short (single-epoch) tuning run.

    Args:
        trial_number: ``trial.number``; gives each trial its own output directory.
        hp: Hyperparameter dict from ``_suggest_hyperparameters``.

    Returns:
        A configured ``TrainingArguments`` instance.
    """
    # ``evaluation_strategy`` was renamed to ``eval_strategy`` in transformers and
    # newer releases reject the old keyword; use whichever the installed version accepts.
    eval_strategy_kw = (
        "eval_strategy"
        if "eval_strategy" in inspect.signature(TrainingArguments).parameters
        else "evaluation_strategy"
    )
    return TrainingArguments(
        output_dir=f"{OUTPUT_DIR}/trial_{trial_number}",
        num_train_epochs=1,  # Short epoch for fast tuning
        per_device_train_batch_size=hp["batch_size"],
        per_device_eval_batch_size=hp["batch_size"],
        gradient_accumulation_steps=hp["gradient_accumulation_steps"],
        learning_rate=hp["learning_rate"],
        warmup_ratio=hp["warmup_ratio"],
        weight_decay=hp["weight_decay"],
        lr_scheduler_type=hp["lr_scheduler_type"],
        fp16=True,
        gradient_checkpointing=True,
        optim="paged_adamw_8bit",
        logging_steps=50,
        eval_steps=200,
        save_steps=200,
        save_total_limit=1,
        load_best_model_at_end=False,
        report_to="none",
        # Keep dataset columns untouched; presumably the subsets are already
        # tokenized for the collator — TODO confirm against the preprocessing step.
        remove_unused_columns=False,
        **{eval_strategy_kw: "steps"},
    )


def objective(trial):
    """Train one LoRA fine-tuning trial and return its validation loss.

    Optuna calls this once per trial. Each call samples a configuration from
    ``trial``, loads ``BASE_MODEL`` in 4-bit with a fresh LoRA adapter, trains
    for one epoch on ``train_subset`` and evaluates on ``val_subset``.

    Args:
        trial: The Optuna trial supplied by ``study.optimize``.

    Returns:
        float: ``eval_loss`` from the final evaluation (the study minimizes it).

    Raises:
        Any exception from model loading or training (e.g. CUDA out-of-memory
        for large batch sizes) propagates to Optuna. The trial's model and
        trainer are released first, on success and failure alike.
    """
    hp = _suggest_hyperparameters(trial)

    model = trainer = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        if tokenizer.pad_token is None:
            # Padding is required for batching; reuse EOS when no pad token exists
            tokenizer.pad_token = tokenizer.eos_token

        model = _load_lora_model(hp)

        trainer = Trainer(
            model=model,
            args=_build_training_args(trial.number, hp),
            train_dataset=train_subset,
            eval_dataset=val_subset,
            # mlm=False -> causal-LM objective rather than masked-token prediction
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        )

        trainer.train()
        return trainer.evaluate()["eval_loss"]
    finally:
        # Runs on success AND on error, so a failed trial does not keep its
        # model on the GPU while the next trial loads another copy.
        del model, trainer
        gc.collect()
        torch.cuda.empty_cache()

# %% [markdown]
# ## Run Optimization

In [None]:
# %%
# Create (or resume) the persistent Optuna study. With SQLite storage and
# load_if_exists=True, re-running this cell adds trials to the same study.
study = optuna.create_study(
    direction="minimize",  # the objective returns validation loss
    study_name="sft_hyperparameter_tuning",
    storage=f"sqlite:///{OUTPUT_DIR}/optuna_study.db",
    load_if_exists=True,
)

# Run optimization
print(f"Starting optimization with {N_TRIALS} trials...\n")
study.optimize(
    objective,
    n_trials=N_TRIALS,
    timeout=TIMEOUT,
    # A CUDA OOM (e.g. batch_size=8 with a large LoRA rank) marks only that
    # trial as FAILED instead of aborting the whole study.
    catch=(torch.cuda.OutOfMemoryError,),
    # Collect garbage after every trial so the previous trial's model is
    # released before the next one is loaded.
    gc_after_trial=True,
    show_progress_bar=True,
)

print("\n" + "=" * 60)
print("Optimization completed!")
print("=" * 60)

# %% [markdown]
# ## Results Analysis

In [None]:
# %%
# Summarize the completed trial with the lowest validation loss
best_trial = study.best_trial

print(f"\nBest Trial: {best_trial.number}")
print(f"Best Validation Loss: {best_trial.value:.4f}")
print("\nBest Hyperparameters:")
for param_name, param_value in best_trial.params.items():
    print(f"  {param_name}: {param_value}")

In [None]:
# %%
# Visualize optimization results:
# first the objective value across trials, then the relative importance
# of each tuned hyperparameter.
for plot_fn in (plot_optimization_history, plot_param_importances):
    fig = plot_fn(study)
    fig.show()

In [None]:
# %%
# Persist the winning configuration so a full training run can reuse it
import json

best_params_path = f"{OUTPUT_DIR}/best_hyperparameters.json"
best_params = dict(
    best_trial=best_trial.number,
    best_loss=best_trial.value,
    params=best_trial.params,
    n_trials=len(study.trials),  # every trial in the study, whatever its final state
)

with open(best_params_path, "w") as fh:
    json.dump(best_params, fh, indent=2)

print(f"\nBest hyperparameters saved to: {best_params_path}")

# %% [markdown]
# ## Trial Results DataFrame

In [None]:
# %%
# Tabulate the study's trials, lowest validation loss ("value") first
import pandas as pd

trials_df = study.trials_dataframe().sort_values("value")

print("\nTop 5 Trials:")
print(trials_df.head())

# Persist the full table in the same directory as the study database
trials_csv_path = f"{OUTPUT_DIR}/trials_results.csv"
trials_df.to_csv(trials_csv_path, index=False)
print(f"\nAll trials saved to: {trials_csv_path}")