### Installation

In [None]:
%%capture
# Dependency installation cell; the %%capture magic above suppresses the pip output.
import os
# The runtime is treated as Colab when any environment variable name contains "COLAB_".
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Pinned packages are installed with --no-deps so pip does not replace them
    # while resolving dependencies.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install transformers==4.51.3
    !pip install --no-deps unsloth
    # NOTE(review): optuna is only installed in this (Colab) branch and is unpinned.
    !pip install optuna

### Unsloth

In [None]:
from unsloth import FastVisionModel # FastLanguageModel for LLMs
import torch

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
# NOTE(review): this list is informational only -- nothing in the visible notebook
# reads `fourbit_models`, and the checkpoint loaded below is not one of its entries.
fourbit_models = [
    "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", # Llama 3.2 vision support
    "unsloth/Llama-3.2-11B-Vision-bnb-4bit",
    "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit", # Can fit in a 80GB card!
    "unsloth/Llama-3.2-90B-Vision-bnb-4bit",

    "unsloth/Pixtral-12B-2409-bnb-4bit",              # Pixtral fits in 16GB!
    "unsloth/Pixtral-12B-Base-2409-bnb-4bit",         # Pixtral base model

    "unsloth/Qwen2-VL-2B-Instruct-bnb-4bit",          # Qwen2 VL support
    "unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
    "unsloth/Qwen2-VL-72B-Instruct-bnb-4bit",

    "unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit",      # Any Llava variant works!
    "unsloth/llava-1.5-7b-hf-bnb-4bit",
] # More models at https://huggingface.co/unsloth

# Load the 4-bit Qwen2.5-VL-7B-Instruct checkpoint. The second return value is
# named `tokenizer`, but later cells call it with both an image and text.
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit",
    load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
)

In [None]:
# Wrap the loaded model with LoRA adapters (rank 16, alpha 16, no dropout, no bias
# terms). All four finetune_* switches are on, so vision, language, attention and
# MLP modules are all selected for finetuning. `model` is rebound to the wrapped model.
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True, # False if not finetuning vision layers
    finetune_language_layers   = True, # False if not finetuning language layers
    finetune_attention_modules = True, # False if not finetuning attention layers
    finetune_mlp_modules       = True, # False if not finetuning MLP layers

    r = 16,           # The larger, the higher the accuracy, but might overfit
    lora_alpha = 16,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
    # target_modules = "all-linear", # Optional now! Can specify a list if needed
)

<a name="Convert"></a>
### Convert train and validation CSVs to JSON format, structuring each conversation as a dictionary suitable for training LLMs.

Every vision finetuning sample must be a conversation in the following format (the conversion code below stores each conversation under a `"messages"` key):

```python
[
{ "role": "user",
  "content": [{"type": "text",  "text": Q}, {"type": "image", "image": image} ]
},
{ "role": "assistant",
  "content": [{"type": "text",  "text": A} ]
},
]
```

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

import pandas as pd
import json

def generate_json_from_csv(
    csv_path: str,
    output_json_path: str,
    gcs_bucket_name: str,
    image_column: str = "Image",
    label_column: str = "Category"
):
    """Convert a labelled image CSV into a JSON file of chat-style training samples.

    Each CSV row becomes ``{"messages": [user_turn, assistant_turn]}``: the user
    turn carries the fixed classification prompt plus the image URL, and the
    assistant turn carries the expected answer ``{"category": "<label>"}``.

    Args:
        csv_path: CSV file to read; must contain ``image_column`` and ``label_column``.
        output_json_path: Destination file; it is overwritten.
        gcs_bucket_name: Bucket (optionally with a path prefix) used to build
            ``https://storage.googleapis.com/<bucket>/<label>/<image>`` URLs.
        image_column: Column holding the image file name or a full URL.
        label_column: Column holding the class label.

    Returns:
        None. The samples are written to ``output_json_path``.
    """
    # The instruction is the same for every sample, so build it once.
    prompt = (
        "Analyze the provided image of an apple leaf using your computer vision capabilities. "
        "Classify the leaf into the most appropriate category based on its condition, "
        "choosing from the predefined list: "
        "{\n  \"categories\": [\n    \"black-rot\",\n    \"healthy\",\n    \"rust\",\n    \"scab\"\n  ]\n} "
        "Provide your final classification in the following JSON format without explanations: "
        "{\"category\": \"chosen_category_name\"}"
    )

    df = pd.read_csv(csv_path)
    data = []

    for raw_image, raw_label in zip(df[image_column], df[label_column]):
        img_path = str(raw_image)
        label = str(raw_label)

        # Construct GCS URL if not already a full URL
        if img_path.startswith("http"):
            full_img_url = img_path
        else:
            full_img_url = f"https://storage.googleapis.com/{gcs_bucket_name}/{label}/{img_path}"

        # json.dumps renders plain labels exactly as {"category": "<label>"} and,
        # unlike hand-built string formatting, escapes quotes/backslashes so the
        # target answer is always valid JSON.
        answer = json.dumps({"category": label}, ensure_ascii=False)

        data.append({  # prompt-completion pair
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image", "image": full_img_url}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer}
                    ]
                }
            ]
        })

    # Save to JSON
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    print(f"CSV data saved to {output_json_path}")

def prepare_test_set(input_csv, gcs_bucket_name, output_csv):
    """Rewrite the ``Image`` column of a test CSV into full GCS URLs.

    Args:
        input_csv: CSV file with ``Image`` and ``Category`` columns.
        gcs_bucket_name: Bucket (optionally with a path prefix) used to build
            ``https://storage.googleapis.com/<bucket>/<Category>/<Image>``.
        output_csv: Destination CSV path; written without the index column.

    Returns:
        None. The updated table is written to ``output_csv``.
    """
    df = pd.read_csv(input_csv)

    def _to_url(image, category):
        """Return the GCS URL for one row, leaving existing URLs untouched."""
        image = str(image)
        # Same rule as generate_json_from_csv: values that already are URLs are
        # kept as-is, so feeding an already-converted CSV does not double-prefix.
        if image.startswith("http"):
            return image
        return f"https://storage.googleapis.com/{gcs_bucket_name}/{category}/{image}"

    # Update the Image column. A list is built column-wise (rather than using
    # DataFrame.apply(axis=1)) so an empty CSV also round-trips cleanly.
    df['Image'] = [_to_url(image, category) for image, category in zip(df['Image'], df['Category'])]

    # Save the updated DataFrame to output_csv
    df.to_csv(output_csv, index=False)

    print(f"Updated CSV saved as {output_csv}")

# Dataset folder on the mounted Drive; file names are appended by plain string
# concatenation, so the trailing "/" is required.
absolute_path = "/content/gdrive/My Drive/Projects/VL-Models/Datasets/"
# Bucket name plus path prefix inserted into every generated image URL.
gcs_bucket_name = "kroumeliotis-image-bucket/plant-disease-detection/Apple/256"

# The conversion calls below are disabled; uncomment them to (re)generate
# train_set_256.json, validation_set_256.json and test_set_256.csv.

# generate_json_from_csv(
#     csv_path = absolute_path + "train_set.csv",
#     output_json_path = absolute_path + "train_set_256.json",
#     gcs_bucket_name = gcs_bucket_name
# )

# generate_json_from_csv(
#     csv_path = absolute_path + "validation_set.csv",
#     output_json_path = absolute_path + "validation_set_256.json",
#     gcs_bucket_name = gcs_bucket_name
# )

# prepare_test_set(
#     input_csv = absolute_path + "test_set.csv",
#     gcs_bucket_name = gcs_bucket_name,
#     output_csv = absolute_path + "test_set_256.csv"
# )

### Load train and validation sets into memory

In [None]:
import json
from google.colab import drive

drive.mount('/content/gdrive') # We can omit it on sequence run
absolute_path = '/content/gdrive/My Drive/Projects/VL-Models/Datasets/'  # We can omit it on sequence run

# Keep the file paths under their own names so `train_set` / `validation_set`
# only ever refer to the loaded samples, never to a path string.
train_set_path = absolute_path + "train_set_256.json"
validation_set_path = absolute_path + "validation_set_256.json"

# Load the dataset into memory
with open(train_set_path, "r", encoding="utf-8") as f:
    train_set = json.load(f)

with open(validation_set_path, "r", encoding="utf-8") as f:
    validation_set = json.load(f)

# Fail with a clear message instead of an IndexError on the prints below.
if not train_set or not validation_set:
    raise ValueError(
        f"Empty dataset: {len(train_set)} training and "
        f"{len(validation_set)} validation samples were loaded"
    )

# Show one sample from each split as a sanity check.
print(train_set[0])
print(validation_set[0])

### Zero-shot Predictions

In [None]:
from PIL import Image
import requests
from io import BytesIO

FastVisionModel.for_inference(model) # Enable for inference!

# One training sample is used for the zero-shot check.
sample = train_set[2]

# Extract user message contents
user_content = sample["messages"][0]["content"]

# Extract image URL and instruction from the nested structure
instruction = next(item["text"] for item in user_content if item["type"] == "text")
image_url = next(item["image"] for item in user_content if item["type"] == "image")
# Download and load the image from URL. The timeout keeps the cell from hanging
# on a stalled connection, and raise_for_status() reports an HTTP error directly
# instead of handing an error page to PIL.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")

# The image itself is passed to the tokenizer call below; the chat template
# only needs a placeholder entry for it.
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": instruction}
    ]}
]

input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
inputs = tokenizer(
    image,
    input_text,
    add_special_tokens = False,
    return_tensors = "pt",
).to("cuda")

from transformers import TextStreamer
# Stream generated tokens to the cell output as they are produced.
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache = True, temperature = 1.5, min_p = 0.1)

## Fine-Tuning Phase + Optuna

In [None]:
!pip install optuna
from google.colab import drive
drive.mount('/content/gdrive')
drive_path = "/content/gdrive/My Drive/Projects/VL-Models/Results/Qwen2.5-VL-7B-Instruct-bnb-4bit-4/"

import optuna
import time
import logging
import os
import pickle
import json
from typing import Dict, Any, Optional
from transformers import TrainerCallback, TrainerState, TrainerControl, EarlyStoppingCallback
import numpy as np
from unsloth import is_bf16_supported
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig

# Configure logging
# NOTE(review): `logger` is created here but the visible code reports progress
# with print() only.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Persistence paths
# File names are appended directly to `drive_path`, so it must end with "/".
STUDY_DB_PATH = f"{drive_path}optuna_study.db"
RESULTS_PATH = f"{drive_path}trial_results.json"
CHECKPOINT_PATH = f"{drive_path}optimization_checkpoint.pkl"

class PersistentMetricsTracker:
    """Track per-trial metrics in memory and mirror them to a JSON file.

    Attributes:
        results_path: JSON file that holds the list of trial result dicts.
        trial_results: In-memory list of result dicts, one per trial number.
    """

    def __init__(self, results_path: str):
        """Remember ``results_path`` and load any results already stored there."""
        self.results_path = results_path
        self.trial_results = self.load_existing_results()

    def load_existing_results(self) -> list:
        """Load existing trial results from file.

        Returns:
            The stored list, or an empty list when the file is missing or
            cannot be read/parsed (the error is printed, not raised).
        """
        if not os.path.exists(self.results_path):
            return []
        try:
            with open(self.results_path, 'r') as f:
                results = json.load(f)
            print(f"Loaded {len(results)} existing trial results")
            return results
        except Exception as e:
            print(f"Error loading existing results: {e}")
            return []

    def save_results(self):
        """Save current trial results to file.

        The data is first written to a sibling ``<results_path>.tmp`` file and
        then moved over the real file with ``os.replace``. Opening the real
        file with mode ``'w'`` would truncate it before the dump starts, so a
        failed or interrupted dump would wipe every previously saved trial.
        Errors are printed, not raised.
        """
        tmp_path = f"{self.results_path}.tmp"
        try:
            with open(tmp_path, 'w') as f:
                json.dump(self.trial_results, f, indent=2)
            os.replace(tmp_path, self.results_path)
        except Exception as e:
            print(f"Error saving results: {e}")
            # Best-effort cleanup of the partial temp file; the previous
            # results file is left untouched.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def log_trial_result(self, trial_number: int, params: Dict[str, Any],
                        final_loss: float, training_time: float,
                        training_losses: list, validation_losses: list,
                        trial_state: str = "COMPLETE"):
        """Record one trial's outcome, persist it, and print a summary.

        Args:
            trial_number: Trial identifier; an existing entry with the same
                number is replaced instead of duplicated.
            params: Hyperparameters used by the trial.
            final_loss: Final validation loss (``float('inf')`` when unknown).
            training_time: Wall-clock duration in seconds; stored in minutes.
            training_losses: Training loss history.
            validation_losses: Validation loss history.
            trial_state: State label stored with the entry.
        """
        result = {
            'trial_number': trial_number,
            'params': params,
            'final_validation_loss': final_loss,
            'training_time_minutes': training_time / 60,
            'training_losses': training_losses,
            'validation_losses': validation_losses,
            'trial_state': trial_state,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
        }

        # Replace the entry for this trial number if one exists, otherwise append.
        existing_trial_idx = next(
            (i for i, existing in enumerate(self.trial_results)
             if existing['trial_number'] == trial_number),
            None,
        )
        if existing_trial_idx is not None:
            self.trial_results[existing_trial_idx] = result
        else:
            self.trial_results.append(result)

        # Save immediately after each trial
        self.save_results()

        # Print trial results
        print(f"\n{'='*60}")
        print(f"TRIAL {trial_number} {trial_state}")
        print(f"{'='*60}")
        print(f"Parameters:")
        for key, value in params.items():
            print(f"  {key}: {value}")
        if final_loss != float('inf'):
            print(f"Final Validation Loss: {final_loss:.6f}")
        print(f"Training Time: {training_time/60:.2f} minutes")
        print(f"{'='*60}\n")

    def get_completed_trial_numbers(self) -> set:
        """Get set of completed trial numbers (entries whose state is 'COMPLETE')."""
        return {
            result['trial_number']
            for result in self.trial_results
            if result['trial_state'] == 'COMPLETE'
        }

class OptunaPruningCallback(TrainerCallback):
    """Callback to enable Optuna pruning during training.

    After every evaluation the monitored metric is reported to the Optuna
    trial; if the study's pruner decides the trial should stop,
    ``optuna.TrialPruned`` is raised, which aborts ``trainer.train()``.
    """

    def __init__(self, trial, monitor_metric="eval_loss"):
        """Store the Optuna ``trial`` and the metric name to report."""
        self.trial = trial
        self.monitor_metric = monitor_metric

    def on_evaluate(self, args, state, control, model=None, metrics=None, **kwargs):
        """Report the latest evaluation metric and prune the trial if requested.

        Args:
            args, state, control, model: Standard Trainer callback arguments.
            metrics: Evaluation metrics dict, when the Trainer supplies it.

        Raises:
            optuna.TrialPruned: When the pruner asks to stop this trial.
        """
        # Prefer the metrics handed to the callback; fall back to the last
        # log entry so the callback also works when `metrics` is not supplied.
        current_value = None
        if metrics and self.monitor_metric in metrics:
            current_value = metrics[self.monitor_metric]
        elif state.log_history and self.monitor_metric in state.log_history[-1]:
            current_value = state.log_history[-1][self.monitor_metric]

        if current_value is None:
            return

        # Optuna steps are integers while `state.epoch` is a float; round it so
        # a value such as 2.9999 is reported as step 3 rather than truncated to 2.
        if state.epoch is not None:
            step = int(round(state.epoch))
        else:
            step = state.global_step
        self.trial.report(current_value, step)

        # Check if trial should be pruned
        if self.trial.should_prune():
            raise optuna.TrialPruned(f"Trial pruned at epoch {state.epoch}")

def save_checkpoint(study, metrics_tracker, current_trial: int):
    """Pickle a small progress snapshot to ``CHECKPOINT_PATH``.

    The snapshot records the trial about to run, how many trials the study
    holds, the trial numbers the tracker marks as complete, and a timestamp.
    Write errors are printed rather than raised.
    """
    trials_in_study = len(study.trials)
    finished_numbers = metrics_tracker.get_completed_trial_numbers()
    saved_at = time.strftime('%Y-%m-%d %H:%M:%S')
    snapshot = dict(
        current_trial=current_trial,
        study_trials_count=trials_in_study,
        completed_trials=finished_numbers,
        timestamp=saved_at,
    )

    try:
        with open(CHECKPOINT_PATH, 'wb') as handle:
            pickle.dump(snapshot, handle)
        print(f"Checkpoint saved at trial {current_trial}")
    except Exception as exc:
        print(f"Error saving checkpoint: {exc}")

def load_checkpoint() -> Optional[Dict]:
    """Read the progress snapshot stored at ``CHECKPOINT_PATH``.

    Returns:
        The unpickled snapshot dict, or ``None`` when the file does not exist
        or cannot be loaded (the error is printed, not raised).
    """
    if not os.path.exists(CHECKPOINT_PATH):
        return None

    try:
        # The checkpoint is a pickle this notebook wrote itself; do not point
        # CHECKPOINT_PATH at files from an untrusted source.
        with open(CHECKPOINT_PATH, 'rb') as handle:
            restored = pickle.load(handle)
        print(f"Loaded checkpoint from {restored['timestamp']}")
    except Exception as exc:
        print(f"Error loading checkpoint: {exc}")
        return None
    return restored

def create_or_load_study():
    """Create new study or load existing one.

    The study lives in the SQLite database at ``STUDY_DB_PATH`` and minimizes
    the objective. The pruner is not stored in that database, so it is built
    here and passed in both the load and the create path; otherwise a resumed
    study would fall back to Optuna's default pruner settings.

    Returns:
        The loaded or newly created ``optuna.Study``.
    """
    storage_url = f"sqlite:///{STUDY_DB_PATH}"
    study_name = "llama_hyperparameter_optimization"

    pruner = optuna.pruners.MedianPruner(
        n_startup_trials=5,
        n_warmup_steps=5,
        interval_steps=2
    )

    try:
        # Try to load existing study. No sampler is passed here, which keeps the
        # sampler Optuna picks by default on load: re-creating
        # TPESampler(seed=3407) on every resume would restart its random stream
        # from the beginning.
        study = optuna.load_study(
            study_name=study_name,
            storage=storage_url,
            pruner=pruner
        )
        print(f"Loaded existing study with {len(study.trials)} trials")
        return study
    except KeyError:
        # Create new study if it doesn't exist
        print("Creating new study...")
        study = optuna.create_study(
            study_name=study_name,
            storage=storage_url,
            direction="minimize",
            pruner=pruner,
            sampler=optuna.samplers.TPESampler(seed=3407),
            load_if_exists=True
        )
        return study

# Initialize persistent metrics tracker
metrics_tracker = PersistentMetricsTracker(RESULTS_PATH)

def _extract_loss_history(log_history):
    """Split a Trainer log history into (training losses, validation losses)."""
    # Per-step training loss is logged under 'loss'; 'train_loss' only appears
    # once, in the summary entry written when training ends.
    training_losses = [entry['loss'] for entry in log_history if 'loss' in entry]
    validation_losses = [entry['eval_loss'] for entry in log_history if 'eval_loss' in entry]
    return training_losses, validation_losses


def _reset_trainable_weights(model):
    """Make every trial start from the same trainable weights.

    The first call stores a CPU copy of all parameters that require gradients
    on the model object itself; later calls copy that snapshot back. Without
    this, each trial would continue from the weights the previous trial left.
    """
    import torch

    snapshot = getattr(model, "_hpo_initial_trainable_state", None)
    if snapshot is None:
        model._hpo_initial_trainable_state = {
            name: param.detach().clone().cpu()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        return

    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in snapshot:
                param.copy_(snapshot[name])


def objective(trial):
    """Objective function for Optuna optimization.

    Samples a hyperparameter set, fine-tunes the shared ``model`` on
    ``train_set`` with per-epoch evaluation on ``validation_set``, records the
    outcome in ``metrics_tracker`` and returns the last validation loss.

    Args:
        trial: The Optuna trial to evaluate.

    Returns:
        The final validation loss, or ``float('inf')`` when training failed or
        produced no evaluation.

    Raises:
        optuna.TrialPruned: Re-raised after logging when the trial is pruned.
    """

    # Check if this trial was already completed
    completed_trials = metrics_tracker.get_completed_trial_numbers()
    if trial.number in completed_trials:
        print(f"Trial {trial.number} was already completed, skipping...")
        # Find the existing result and return its loss
        for result in metrics_tracker.trial_results:
            if (result['trial_number'] == trial.number and
                result['trial_state'] == 'COMPLETE'):
                return result['final_validation_loss']
        return float('inf')

    # Suggest hyperparameters
    # learning_rate = trial.suggest_float("learning_rate", 5e-5, 5e-4, log=True)
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 5e-4, log=True)
    # per_device_train_batch_size = trial.suggest_categorical("per_device_train_batch_size", [1, 2, 4])
    per_device_train_batch_size = trial.suggest_categorical("per_device_train_batch_size", [2, 4])
    gradient_accumulation_steps = trial.suggest_categorical("gradient_accumulation_steps", [4, 8, 16])
    # warmup_ratio = trial.suggest_float("warmup_ratio", 0.0, 0.3)
    warmup_ratio = trial.suggest_float("warmup_ratio", 0.01, 0.1)
    # weight_decay = trial.suggest_float("weight_decay", 0.0, 0.1)
    weight_decay = trial.suggest_float("weight_decay", 0.0, 0.05)
    # num_train_epochs = trial.suggest_int("num_train_epochs", 2, 6)
    num_train_epochs = trial.suggest_int("num_train_epochs", 6, 15)

    # Calculate warmup steps based on dataset size and batch configuration
    total_samples = len(train_set)
    steps_per_epoch = total_samples // (per_device_train_batch_size * gradient_accumulation_steps)
    total_steps = steps_per_epoch * num_train_epochs
    warmup_steps = int(total_steps * warmup_ratio)

    print(f"\nStarting Trial {trial.number}")
    print(f"Parameters: {trial.params}")

    # Defined before the try block so the except handlers can always use them.
    trainer = None
    start_time = time.time()

    try:
        # Enable model for training
        FastVisionModel.for_training(model)

        # Start from the same weights as every other trial.
        _reset_trainable_weights(model)

        # Create trial-specific output directory
        trial_output_dir = f"{drive_path}outputs/trial_{trial.number}"
        os.makedirs(trial_output_dir, exist_ok=True)

        # Create trainer with suggested hyperparameters
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            data_collator=UnslothVisionDataCollator(model, tokenizer),
            train_dataset=train_set,
            eval_dataset=validation_set,
            args=SFTConfig(
                per_device_train_batch_size=per_device_train_batch_size,
                gradient_accumulation_steps=gradient_accumulation_steps,
                warmup_steps=warmup_steps,
                num_train_epochs=num_train_epochs,
                learning_rate=learning_rate,
                fp16=not is_bf16_supported(),
                bf16=is_bf16_supported(),
                logging_steps=1,
                optim="adamw_8bit",
                weight_decay=weight_decay,
                lr_scheduler_type="linear",
                seed=3407,
                output_dir=trial_output_dir,
                report_to="none",

                # Evaluation settings for pruning
                eval_strategy="epoch",
                eval_steps=1,
                save_strategy= "epoch", #"no", do not save weights on each checkpoint
                save_total_limit=1,
                load_best_model_at_end=True,
                metric_for_best_model="eval_loss",
                greater_is_better=False,

                # Vision-specific settings
                remove_unused_columns=False,
                dataset_text_field="messages",
                dataset_kwargs={"skip_prepare_dataset": True},
                dataset_num_proc=4,
                max_seq_length=2048,
            ),
        )

        # Add callbacks for pruning and early stopping
        trainer.add_callback(OptunaPruningCallback(trial, monitor_metric="eval_loss"))
        trainer.add_callback(EarlyStoppingCallback(
            early_stopping_patience=5,
            early_stopping_threshold=0.003
        ))

        # Train the model
        trainer.train()

        training_time = time.time() - start_time

        # Extract metrics from training history
        training_losses, validation_losses = _extract_loss_history(trainer.state.log_history)

        # Get final validation loss
        # NOTE(review): this is the loss of the LAST evaluation, while
        # load_best_model_at_end=True restores the best checkpoint; confirm
        # whether min(validation_losses) is the intended objective value.
        final_validation_loss = validation_losses[-1] if validation_losses else float('inf')

        # Log trial results with COMPLETE state
        metrics_tracker.log_trial_result(
            trial_number=trial.number,
            params=trial.params,
            final_loss=final_validation_loss,
            training_time=training_time,
            training_losses=training_losses,
            validation_losses=validation_losses,
            trial_state="COMPLETE"
        )

        return final_validation_loss

    except optuna.TrialPruned as e:
        training_time = time.time() - start_time

        print(f"Trial {trial.number} was pruned early: {str(e)}")

        # Keep whatever history was gathered before the trial was pruned.
        training_losses, validation_losses = _extract_loss_history(
            trainer.state.log_history if trainer is not None else []
        )

        # Log pruned trial results
        metrics_tracker.log_trial_result(
            trial_number=trial.number,
            params=trial.params,
            final_loss=float('inf'),
            training_time=training_time,
            training_losses=training_losses,
            validation_losses=validation_losses,
            trial_state="PRUNED"
        )
        raise

    except Exception as e:
        training_time = time.time() - start_time

        print(f"Trial {trial.number} failed with error: {str(e)}")

        # Keep whatever history was gathered before the failure.
        training_losses, validation_losses = _extract_loss_history(
            trainer.state.log_history if trainer is not None else []
        )

        # Log failed trial results. Returning inf (instead of raising) means
        # Optuna records this trial with an infinite objective value.
        metrics_tracker.log_trial_result(
            trial_number=trial.number,
            params=trial.params,
            final_loss=float('inf'),
            training_time=training_time,
            training_losses=training_losses,
            validation_losses=validation_losses,
            trial_state="FAILED"
        )
        return float('inf')

def run_optimization_with_resume(n_trials: int = 20):
    """Run optimization with resume capability.

    Trials are run one at a time, with a checkpoint written before each, until
    the study holds ``n_trials`` finished trials.

    Args:
        n_trials: Target number of finished trials for the whole study,
            including those from earlier sessions.

    Returns:
        The ``optuna.Study`` that was created or resumed.
    """

    # Create or load study
    study = create_or_load_study()

    # Load checkpoint if available (used only to report the resume point).
    checkpoint = load_checkpoint()
    if checkpoint:
        completed_trials = metrics_tracker.get_completed_trial_numbers()
        print(f"Resuming from trial {len(completed_trials)} (completed: {len(completed_trials)})")

    finished_states = (
        optuna.trial.TrialState.COMPLETE,
        optuna.trial.TrialState.PRUNED,
        optuna.trial.TrialState.FAIL,
    )

    def _finished_trial_count() -> int:
        """Count trials that actually ended (complete, pruned or failed)."""
        return sum(1 for t in study.trials if t.state in finished_states)

    # Calculate remaining trials. Only finished trials count: a trial left in
    # RUNNING or WAITING state by an interrupted session never produced a
    # result and must not use up the budget.
    remaining_trials = max(0, n_trials - _finished_trial_count())

    if remaining_trials == 0:
        print(f"All {n_trials} trials already completed!")
        return study

    print(f"Running {remaining_trials} remaining trials...")
    print("Each trial will be saved immediately and can be resumed if interrupted.\n")

    # Only enqueue a starting point if the study has no trials at all yet.
    if len(study.trials) == 0:
        study.enqueue_trial({
            "learning_rate": 2e-4,
            "per_device_train_batch_size": 2,
            "gradient_accumulation_steps": 8,
            "warmup_ratio": 0.05,  # estimated from warmup_steps=5
            "weight_decay": 0.01,
            "num_train_epochs": 10
        })

    # Custom optimization loop with checkpointing
    for _ in range(remaining_trials):
        try:
            # Save checkpoint before each trial
            save_checkpoint(study, metrics_tracker, len(study.trials))

            # Run single trial
            study.optimize(objective, n_trials=1, timeout=None)

            print(f"Progress: {_finished_trial_count()}/{n_trials} trials completed")

        except KeyboardInterrupt:
            print("\nOptimization interrupted by user")
            break
        except Exception as e:
            print(f"Error in trial: {e}")
            print("Continuing with next trial...")
            continue

    return study

# Main execution
if __name__ == "__main__":
    # Configuration
    n_trials = 50  # Adjust based on your needs

    print(f"Starting/Resuming Bayesian Optimization with up to {n_trials} trials...")
    print(f"Results will be saved to: {RESULTS_PATH}")
    print(f"Study database: {STUDY_DB_PATH}")
    print(f"Checkpoints: {CHECKPOINT_PATH}\n")

    # Run optimization
    study = run_optimization_with_resume(n_trials)

    # Print final results
    print(f"\n{'='*80}")
    print("OPTIMIZATION STATUS")
    print(f"{'='*80}")

    completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    failed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL]

    print(f"Total trials: {len(study.trials)}")
    print(f"Completed trials: {len(completed_trials)}")
    print(f"Pruned trials: {len(pruned_trials)}")
    print(f"Failed trials: {len(failed_trials)}")

    # study.best_trial raises when no trial completed, so resolve it once here.
    best_trial = study.best_trial if completed_trials else None

    if best_trial is not None:
        print(f"\nBest trial:")
        print(f"  Value (Validation Loss): {best_trial.value:.6f}")
        print(f"  Trial Number: {best_trial.number}")
        print(f"  Params:")
        for key, value in best_trial.params.items():
            print(f"    {key}: {value}")

        # Print top 5 trials
        print(f"\nTop 5 trials:")
        top_trials = sorted(completed_trials, key=lambda x: x.value)[:5]

        for rank, top_trial in enumerate(top_trials, 1):
            print(f"\n  Rank {rank}:")
            print(f"    Validation Loss: {top_trial.value:.6f}")
            print(f"    Trial Number: {top_trial.number}")
            print(f"    Params: {top_trial.params}")

        # Option to train final model with best hyperparameters
        retrain_best_model = True  # Set to False if you don't want to retrain

        # The objective returns float('inf') for failed runs, and Optuna stores
        # those as completed trials. If the best value is infinite, no run
        # actually succeeded and its parameters are not worth retraining with.
        best_value_is_finite = best_trial.value is not None and best_trial.value < float('inf')

        if retrain_best_model and not best_value_is_finite:
            print("\nSkipping final training: no trial finished with a finite validation loss.")

        if retrain_best_model and best_value_is_finite:
            print(f"\n{'='*60}")
            print("TRAINING FINAL MODEL WITH BEST HYPERPARAMETERS")
            print(f"{'='*60}")

            best_params = best_trial.params

            # Calculate warmup steps for best params
            total_samples = len(train_set)
            steps_per_epoch = total_samples // (best_params['per_device_train_batch_size'] *
                                               best_params['gradient_accumulation_steps'])
            total_steps = steps_per_epoch * best_params['num_train_epochs']
            warmup_steps = int(total_steps * best_params['warmup_ratio'])

            # Create final trainer
            # NOTE(review): `model` is the same object the search trials trained,
            # so this run starts from whatever weights it currently holds rather
            # than from a freshly initialised adapter -- confirm this is intended.
            FastVisionModel.for_training(model)

            final_output_dir = f"{drive_path}final_model"
            os.makedirs(final_output_dir, exist_ok=True)

            final_trainer = SFTTrainer(
                model=model,
                tokenizer=tokenizer,
                data_collator=UnslothVisionDataCollator(model, tokenizer),
                train_dataset=train_set,
                eval_dataset=validation_set,
                args=SFTConfig(
                    per_device_train_batch_size=best_params['per_device_train_batch_size'],
                    gradient_accumulation_steps=best_params['gradient_accumulation_steps'],
                    warmup_steps=warmup_steps,
                    num_train_epochs=best_params['num_train_epochs'],
                    learning_rate=best_params['learning_rate'],
                    fp16=not is_bf16_supported(),
                    bf16=is_bf16_supported(),
                    logging_steps=1,
                    optim="adamw_8bit",
                    weight_decay=best_params['weight_decay'],
                    lr_scheduler_type="linear",
                    seed=3407,
                    output_dir=final_output_dir,
                    report_to="none",

                    eval_strategy="epoch",
                    save_strategy="epoch",
                    save_total_limit=3,
                    load_best_model_at_end=True,
                    metric_for_best_model="eval_loss",
                    greater_is_better=False,

                    remove_unused_columns=False,
                    dataset_text_field="messages",
                    dataset_kwargs={"skip_prepare_dataset": True},
                    dataset_num_proc=4,
                    max_seq_length=2048,
                ),
            )

            # Add early stopping callback to final trainer
            final_early_stopping = EarlyStoppingCallback(
                early_stopping_patience=3,
                early_stopping_threshold=0.01
            )
            final_trainer.add_callback(final_early_stopping)

            # Train final model
            print("Training final model...")
            final_stats = final_trainer.train()

            # Save the best model
            best_model_path = f"{drive_path}best_model"
            os.makedirs(best_model_path, exist_ok=True)
            final_trainer.save_model(best_model_path)
            tokenizer.save_pretrained(best_model_path)

            print(f"Final model saved to '{best_model_path}'!")

    print(f"\n{'='*60}")
    print("HYPERPARAMETER OPTIMIZATION COMPLETE!")
    print(f"{'='*60}")

    # Save final comprehensive results
    final_results = {
        'study_trials': len(study.trials),
        'completed_trials': len(completed_trials),
        'best_params': best_trial.params if best_trial is not None else None,
        'best_loss': best_trial.value if best_trial is not None else None,
        'all_trial_results': metrics_tracker.trial_results
    }

    final_results_path = f'{drive_path}final_optimization_results.json'
    with open(final_results_path, 'w') as f:
        json.dump(final_results, f, indent=2)

    print(f"Complete results saved to '{final_results_path}'")