In [None]:
# -----------------------------------------------------------------------------
# BatteryMind – Hyper-parameter Tuning (Optuna) - Transformer Battery Predictor
# -----------------------------------------------------------------------------
# Cell 1 – Environment & Imports
import os, json, yaml, time, warnings, logging, random, tempfile
import numpy as np
import pandas as pd
import optuna
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments
)
from ai_models.transformers.battery_health_predictor.data_loader import (
    BatterySequenceDataset
)
from ai_models.utils.logging_utils import configure_logging
from ai_models.config import training_config_path  # YAML from config/

# Install a blanket "ignore" filter so no Python warnings are shown from here on.
warnings.filterwarnings("ignore")
# Project helper from ai_models.utils.logging_utils; presumably installs the
# handlers/formatters used by the loggers below — verify against that module.
configure_logging()
# Named logger for this tuning run.
logger = logging.getLogger("tuning")

# Cell 2 – Load Training Configuration
# Parse the YAML training config, then pull out the three settings the later
# cells rely on: the telemetry CSV path, the base checkpoint and the label count.
with open(training_config_path) as f:  # default mode: read-only text
    cfg = yaml.safe_load(f)

_data_section = cfg["data"]
DATA_PATH = _data_section["telemetry_csv"]
_model_section = cfg["model"]
MODEL_NAME = _model_section["base_checkpoint"]
NUM_LABELS = _model_section["num_labels"]

# Cell 3 – Load & Split Dataset
# Read the telemetry CSV and hold out 10 % of the rows for validation.
df = pd.read_csv(DATA_PATH)
# shuffle=False splits by row position, so the CSV's row order is preserved
# (assumes the file is already in chronological order — TODO confirm).
train_df, valid_df = train_test_split(df, shuffle=False, test_size=0.1)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Wrap both splits with the same tokenizer: training split first, then validation.
train_ds, valid_ds = (
    BatterySequenceDataset(split, tokenizer) for split in (train_df, valid_df)
)

# Cell 4 – Objective Function for Optuna
def objective(trial):
    """Train one candidate configuration and return its validation loss.

    Samples a learning rate, epoch count, batch size, weight decay, warm-up
    ratio and gradient-accumulation factor from ``trial``, fine-tunes a fresh
    ``MODEL_NAME`` checkpoint on ``train_ds`` and scores it on ``valid_ds``.

    Args:
        trial: The Optuna trial used to sample hyper-parameters. The sampled
            values are stored under the short names ``lr``, ``epochs``,
            ``batch``, ``wd``, ``warmup`` and ``accum``.

    Returns:
        The ``eval_loss`` reported by ``Trainer.evaluate()`` on ``valid_ds``
        once training has finished.
    """
    import shutil  # local import: only needed to clean up the scratch dir

    hyperparams = {
        "learning_rate": trial.suggest_float("lr", 1e-5, 1e-3, log=True),
        "num_train_epochs": trial.suggest_int("epochs", 3, 15),
        "per_device_train_batch_size": trial.suggest_categorical(
            "batch", [8, 16, 32]
        ),
        "weight_decay": trial.suggest_float("wd", 0.0, 0.3),
        "warmup_ratio": trial.suggest_float("warmup", 0.0, 0.3),
        "gradient_accumulation_steps": trial.suggest_int("accum", 1, 4),
        "fp16": True,
    }

    # Each trial gets its own throw-away output directory; remove it afterwards
    # so a long study does not litter the temp folder.
    scratch_dir = tempfile.mkdtemp()
    try:
        args = TrainingArguments(
            output_dir=scratch_dir,
            evaluation_strategy="epoch",
            save_strategy="no",
            **hyperparams
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME, num_labels=NUM_LABELS
        )
        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=train_ds,
            eval_dataset=valid_ds,
            tokenizer=tokenizer,
        )
        trainer.train()
        # Score on the held-out split: tuning against the training loss would
        # simply reward configurations that overfit.
        eval_metric = trainer.evaluate()["eval_loss"]
    finally:
        shutil.rmtree(scratch_dir, ignore_errors=True)
    return eval_metric

# Cell 5 – Run Study
# Minimise the objective, bounded by a trial budget and a wall-clock budget.
_N_TRIALS = 50
_TIMEOUT_SECONDS = 2 * 60 * 60  # 2 h
study = optuna.create_study(direction="minimize")
study.optimize(objective, timeout=_TIMEOUT_SECONDS, n_trials=_N_TRIALS)

# Cell 6 – Save Best Params
# Persist the best trial's parameters as indented JSON and echo them.
best_params = study.best_params
with open("best_transformer_params.json", "w") as f:
    f.write(json.dumps(best_params, indent=2))
print("Best params:", best_params)

# Cell 7 – Train Final Model with Best Params
# study.best_params is keyed by the Optuna parameter names ("lr", "epochs", …),
# which TrainingArguments does not accept as keywords. Translate them to the
# real argument names first; keys not listed here are passed through unchanged.
_OPTUNA_TO_TRAINING_ARG = {
    "lr": "learning_rate",
    "epochs": "num_train_epochs",
    "batch": "per_device_train_batch_size",
    "wd": "weight_decay",
    "warmup": "warmup_ratio",
    "accum": "gradient_accumulation_steps",
}
best_training_kwargs = {
    _OPTUNA_TO_TRAINING_ARG.get(name, name): value
    for name, value in best_params.items()
}

args_best = TrainingArguments(
    output_dir="./transformer_best",
    evaluation_strategy="epoch",
    # load_best_model_at_end needs checkpoints on the same schedule as the
    # evaluations, otherwise TrainingArguments rejects the configuration.
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    logging_steps=100,
    fp16=True,
    **best_training_kwargs
)
# Start again from the base checkpoint rather than reusing a trial's weights.
final_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=NUM_LABELS
)
trainer_best = Trainer(
    model=final_model,
    args=args_best,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=tokenizer,
)
trainer_best.train()
trainer_best.save_model("./transformer_best")

# Cell 8 – Register Model Artefact
from ai_models.model_artifacts import (
    get_model_manager, create_model_metadata
)

# optuna's Study has no ``study_duration`` attribute, so derive the wall-clock
# span of the search from the per-trial timestamps: earliest trial start to
# latest trial completion. Trials without a timestamp are skipped.
_trial_starts = [
    t.datetime_start for t in study.trials if t.datetime_start is not None
]
_trial_ends = [
    t.datetime_complete for t in study.trials if t.datetime_complete is not None
]
if _trial_starts and _trial_ends:
    tuning_duration_hours = (
        max(_trial_ends) - min(_trial_starts)
    ).total_seconds() / 3600
else:
    tuning_duration_hours = 0.0

mm = get_model_manager()
metadata = create_model_metadata(
    model_id=f"transformer_bhp_{int(time.time())}",  # unique per run via epoch seconds
    model_type="transformer",
    version="1.0.1",
    name="Transformer BHP Opt-tuned",
    description="Optuna tuned Transformer Battery-Health predictor",
    training_metrics={"loss": study.best_value},
    training_dataset=DATA_PATH,
    training_duration_hours=tuning_duration_hours,
    hyperparameters=best_params,
    created_by=os.getenv("USER", "notebook")
)
# NOTE(review): the weights file name depends on how Trainer.save_model
# serialises the model (recent transformers releases default to safetensors) —
# confirm this path exists before relying on the registration.
mm.register_model(metadata, {"model.pkl": "./transformer_best/pytorch_model.bin"})
print("Model registered:", metadata.model_id)
