In [None]:
!pip install -q transformers datasets sentencepiece tf-keras

In [None]:
!pip install optuna

In [None]:
# -*- coding: utf-8 -*-
import logging

from datasets import load_dataset, DatasetDict
from transformers import (
    BertTokenizerFast,
    EncoderDecoderModel,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
import optuna

# ─── 1) SETUP LOGGING ──────────────────────────────────────────────────────────
# Configure the root logger once (INFO level, timestamped records) and create
# the module-level logger used by the rest of the script.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y/%m/%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
logger = logging.getLogger(__name__)

# ─── 2) LOAD & SPLIT EUROPARL EN–ES ────────────────────────────────────────────
# Load the corpus and, when it ships without a "validation" split, derive the
# missing splits from "train" with a fixed seed so runs are reproducible.
logger.info("Loading Europarl English–Spanish dataset…")
raw = load_dataset("europarl_bilingual", "en-es")
if "validation" not in raw:
    logger.info("Creating a 10% validation split…")
    split = raw["train"].train_test_split(test_size=0.1, seed=42)
    train_ds, val_ds = split["train"], split["test"]
    if "test" in raw:
        # Reuse the published test split untouched.
        test_ds = raw["test"]
    else:
        # Carve the test set out of the remaining training rows and keep only
        # the complement as "train": the three splits must be disjoint,
        # otherwise test scores are inflated by examples seen in training.
        holdout = train_ds.train_test_split(test_size=0.2, seed=42)
        train_ds, test_ds = holdout["train"], holdout["test"]
    raw = DatasetDict({
        "train": train_ds,
        "validation": val_ds,
        "test": test_ds,
    })

# ─── 3) SUBSAMPLE FOR SPEED ──────────────────────────────────────────────────
# Cap the train/validation splits at a fixed number of rows (the first rows of
# each split are kept) so that every HPO trial finishes in reasonable time.
max_train, max_val = 30_000, 3_000
for split_name, row_cap in (("train", max_train), ("validation", max_val)):
    if len(raw[split_name]) > row_cap:
        raw[split_name] = raw[split_name].select(range(row_cap))

# ─── 4) TOKENIZATION ──────────────────────────────────────────────────────────
tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
max_len = 128  # maximum number of tokens kept per source/target sentence


def preprocess(batch, idxs):
    """Tokenize one batch of EN→ES translation pairs.

    Args:
        batch: Batched examples; ``batch["translation"]`` is a list of dicts
            with ``"en"`` (source) and ``"es"`` (target) strings.
        idxs: Dataset indices of the examples in ``batch`` (only used for the
            progress log line).

    Returns:
        The tokenizer output for the English sources, with the Spanish target
        token ids stored under ``"labels"``. Sequences are truncated to
        ``max_len`` tokens but deliberately NOT padded here.
    """
    logger.info("Tokenizing examples %s–%s…", idxs[0], idxs[-1])
    pairs = batch["translation"]
    inputs = [pair["en"] for pair in pairs]
    targets = [pair["es"] for pair in pairs]
    # No `padding="max_length"`: padding the targets here would turn pad ids
    # into real labels and the loss would be computed on padding. Padding is
    # left to the batch collator, which pads dynamically and (for
    # DataCollatorForSeq2Seq) fills label padding with the ignore index -100.
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager and returns the target ids directly under "labels".
    return tokenizer(
        inputs,
        text_target=targets,
        max_length=max_len,
        truncation=True,
    )

# Tokenize every split in large batches; the raw text columns are dropped so
# that only the model inputs produced by `preprocess` remain.
source_columns = raw["train"].column_names
tokenized = raw.map(
    preprocess,
    with_indices=True,  # `preprocess` receives the example indices for logging
    batched=True,
    batch_size=5000,
    remove_columns=source_columns,
)

# ─── 5) DATA COLLATOR ─────────────────────────────────────────────────────────
# Batches are padded to the longest sequence they contain; no model is attached
# to the collator.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=None,
    padding="longest",
)

# ─── 6) MODEL INIT ────────────────────────────────────────────────────────────
def model_init():
    """Create a new encoder-decoder model from multilingual BERT checkpoints.

    Both the encoder and the decoder are loaded from
    ``bert-base-multilingual-cased`` with ``tie_encoder_decoder=True``, and the
    model config is then filled in with decoder flags, special-token ids taken
    from the module-level ``tokenizer``, and generation length/repetition
    settings.

    Returns:
        The configured ``EncoderDecoderModel`` instance.
    """
    checkpoint = "bert-base-multilingual-cased"
    seq2seq = EncoderDecoderModel.from_encoder_decoder_pretrained(
        checkpoint,
        checkpoint,
        tie_encoder_decoder=True,
    )
    cfg = seq2seq.config

    # Mark the decoder half as a causal decoder with cross-attention.
    # NOTE(review): presumably `from_encoder_decoder_pretrained` already sets
    # these flags while loading the decoder — confirm; kept for explicitness.
    cfg.decoder.is_decoder = True
    cfg.decoder.add_cross_attention = True

    # Special tokens and generation constraints, applied in a fixed order.
    generation_settings = {
        "decoder_start_token_id": tokenizer.cls_token_id,
        "eos_token_id": tokenizer.sep_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "max_length": 128,
        "min_length": 10,
        "no_repeat_ngram_size": 3,
    }
    for option, value in generation_settings.items():
        setattr(cfg, option, value)

    return seq2seq

# ─── 7) HYPERPARAMETER SPACE ─────────────────────────────────────────────────
def hp_space(trial: "optuna.Trial"):
    """Sample one set of training hyperparameters for an Optuna trial.

    Args:
        trial: The Optuna trial used to draw the values.

    Returns:
        A dict mapping hyperparameter names to the values suggested for this
        trial: learning rate, per-device train batch size, weight decay,
        warmup steps and number of epochs.
    """
    return {
        # Log-scale sampling: learning rates are searched across orders of
        # magnitude. `suggest_float(..., log=True)` replaces the deprecated
        # `suggest_loguniform`.
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 5e-5, log=True),
        # smaller batch-size choices to avoid OOM
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [4, 8, 16]
        ),
        # `suggest_float` replaces the deprecated `suggest_uniform`.
        "weight_decay": trial.suggest_float("weight_decay", 0.0, 0.3),
        "warmup_steps": trial.suggest_int("warmup_steps", 0, 1000),
        "num_train_epochs": trial.suggest_categorical("num_train_epochs", [2, 3, 4]),
    }

# ─── 8) TUNING ARGS ────────────────────────────────────────────────────────────
# Base training arguments shared by every HPO trial; the values sampled in
# `hp_space` are meant to override the matching fields per trial.
tuning_args = Seq2SeqTrainingArguments(
    output_dir="./hp_tuning",

    # Batch sizes (the train value is only a default, overridden in hp_space).
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,

    # Evaluate and checkpoint on the same 500-step cadence, keeping at most
    # three checkpoints.
    evaluation_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    logging_steps=100,

    # Restore the checkpoint with the lowest evaluation loss when training ends.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # AdamW optimizer, selected via the "adamw_torch" option.
    optim="adamw_torch",

    # Mixed precision stays off for now (HPO + AMP bug pending a fix).
    fp16=False,
)

# ─── 9) TRAINER & HPO RUN ─────────────────────────────────────────────────────
# The trainer receives `model_init` (a factory) rather than a model instance.
# No metric function is supplied; plug in a BLEU function here if desired.
trainer = Seq2SeqTrainer(
    args=tuning_args,
    model_init=model_init,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    compute_metrics=None,
)

# Run 20 sequential Optuna trials that minimize the objective, using a median
# pruner, then report the winning hyperparameters.
best = trainer.hyperparameter_search(
    hp_space=hp_space,
    backend="optuna",
    direction="minimize",
    n_trials=20,
    n_jobs=1,
    study_name="bert_translation_hp",
    pruner=optuna.pruners.MedianPruner(),
)

print("Best hyperparameters:", best.hyperparameters)
