In [None]:
# bert_encoder_from_scratch_with_pooling_multitype_allpairs_bo.py
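# BERT-style encoder (MLM + NSP + MoE), refactored, plus Bayesian Optimization of hyperparameters (Gaussian Process + EI)
#  - 3 training epochs per BO trial (TRIAL_EPOCHS)
#  - 10 trial evaluations in total (BO_ITERATIONS): the first BO_INIT_POINTS are random, the rest are chosen by EI
#  - Logs saved to bo_logs/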

import random
import math
import os
import json
from typing import List, Tuple, Dict, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from gensim.models import Word2Vec
from datetime import datetime

# for Gaussian Process BO
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel, ConstantKernel as C
from scipy.stats import norm

# -------------------------
# Config (DEFAULTS)
# -------------------------
# Module-level defaults. Several of them are used as default argument values
# further down, which means they are read once when those functions are defined.
VOCAB_MIN_FREQ = 1          # minimum token count for a token to enter the vocabulary (build_vocab)
MAX_SEQ_LEN = 1024          # default dataset truncation length; also passed as max_position_embeddings
HIDDEN_SIZE = 768           # model width; the embedding matrix is built with this width
NUM_LAYERS = 12             # BertEncoderModel constructor default (BO trials pass their own value)
NUM_HEADS = 12              # BertEncoderModel constructor default (BO trials pass their own value)
FFN_DIM = 3072              # BertEncoderModel constructor default (BO trials pass their own value)
DROPOUT = 0.1               # dropout passed to every TransformerEncoderLayer
WORD2VEC_SIZE = HIDDEN_SIZE # Word2Vec vector size defaults to the model width
WORD2VEC_WINDOW = 5
WORD2VEC_MIN_COUNT = 1
MLM_MASK_PROB = 0.15        # default per-token masking probability for MLM
BATCH_SIZE = 8
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-5        # NOTE(review): not referenced elsewhere in the visible source; BO samples its own learning rate

# -------------------------
# Special tokens
# -------------------------
PAD_TOKEN = "[PAD]"
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"
MASK_TOKEN = "[MASK]"
UNK_TOKEN = "[UNK]"
# Order matters: build_vocab assigns ids 0..len-1 in this order, and the MLM
# masking helper treats every id below len(SPECIAL_TOKENS) as non-maskable.
SPECIAL_TOKENS = [PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN]

# -------------------------
# BO Settings
# -------------------------
BO_ITERATIONS = 10   # total number of trial evaluations, initial random points included
BO_INIT_POINTS = 3   # random initial points
TRIAL_EPOCHS = 3     # training epochs per trial evaluation
BO_LOG_DIR = "bo_logs"

# Side effect at import time: the log directories are created immediately.
os.makedirs(BO_LOG_DIR, exist_ok=True)
os.makedirs(os.path.join(BO_LOG_DIR, "best_model"), exist_ok=True)

# -------------------------
# Utility: Vocab builder
# -------------------------
def build_vocab(sentences: List[str], min_freq: int = VOCAB_MIN_FREQ):
    """Build the token <-> id mappings from whitespace-tokenized text.

    The special tokens always come first (ids 0..len(SPECIAL_TOKENS)-1, in
    the order of SPECIAL_TOKENS). Every other token whose corpus count is at
    least ``min_freq`` follows, in order of first appearance.

    Args:
        sentences: raw strings; each is stripped and split on whitespace.
        min_freq: minimum number of occurrences required for inclusion.

    Returns:
        (stoi, itos): dict token -> id and list id -> token.
    """
    from collections import Counter

    frequencies = Counter(
        word for line in sentences for word in line.strip().split()
    )

    itos = list(SPECIAL_TOKENS)
    stoi = {special: position for position, special in enumerate(itos)}

    for word, freq in frequencies.items():
        # Skip rare tokens and anything already registered (e.g. a special
        # token that literally appears in the text).
        if freq < min_freq or word in stoi:
            continue
        stoi[word] = len(itos)
        itos.append(word)

    return stoi, itos

# -------------------------
# Train or load Word2Vec
# -------------------------
def train_word2vec(sentences: List[str], vector_size=WORD2VEC_SIZE, window=WORD2VEC_WINDOW, min_count=WORD2VEC_MIN_COUNT, epochs=5):
    """Train a gensim Word2Vec model on whitespace-tokenized sentences.

    Args:
        sentences: raw strings; each is stripped and split on whitespace.
        vector_size, window, min_count, epochs: forwarded to ``Word2Vec``.

    Returns:
        The trained ``Word2Vec`` model (constructed with ``sg=0``, which
        gensim documents as the CBOW training algorithm).
    """
    token_lists = []
    for line in sentences:
        token_lists.append(line.strip().split())

    return Word2Vec(
        sentences=token_lists,
        vector_size=vector_size,
        window=window,
        min_count=min_count,
        epochs=epochs,
        sg=0,
    )

def build_embedding_matrix(w2v: Word2Vec, itos: List[str], hidden_size: int):
    """Create the initial token-embedding table from Word2Vec vectors.

    Every row starts as N(0, 0.02^2) noise. Rows whose token is known to
    ``w2v`` are overwritten with the Word2Vec vector, truncated or
    zero-padded on the right to ``hidden_size``. The PAD row is zeroed last.

    Args:
        w2v: trained Word2Vec model.
        itos: id -> token list; row ``i`` of the result belongs to ``itos[i]``.
        hidden_size: width of every embedding row.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(len(itos), hidden_size)``.
    """
    table = np.random.normal(scale=0.02, size=(len(itos), hidden_size)).astype(np.float32)
    keyed_vectors = w2v.wv

    for row, token in enumerate(itos):
        if token not in keyed_vectors:
            continue  # keep the random initialisation
        vector = keyed_vectors[token]
        width = vector.shape[0]
        if width > hidden_size:
            vector = vector[:hidden_size]
        elif width < hidden_size:
            vector = np.pad(vector, (0, hidden_size - width))
        table[row] = vector

    # PAD must not contribute anything, whatever Word2Vec learned for it.
    table[itos.index(PAD_TOKEN)] = np.zeros(hidden_size, dtype=np.float32)
    return torch.tensor(table)

# -------------------------
# Dataset (supports queries and chunks)
# -------------------------
class BertPretrainingDataset(Dataset):
    """Dataset of MLM-only queries and sentence pairs for MLM + NSP.

    Each entry of ``self.samples`` is a 4-tuple
    ``(sentence_a, kind, sentence_b, nsp_label)``:

    * kind ``"Q"``: a single query; ``sentence_b`` and ``nsp_label`` are None.
    * kind ``"C"``: a sentence pair taken from one chunk; ``nsp_label`` is 1
      for consecutive sentences and 0 for pairs more than one position apart.
    """

    def __init__(self, data: List[Tuple[str, str]], stoi: dict, max_seq_len=MAX_SEQ_LEN):
        """
        data: list of tuples [(text, discriminator)], where discriminator ∈ {'Q', 'C'}
        stoi: token -> id mapping (must contain the special tokens)
        max_seq_len: maximum length of an encoded example, special tokens included

        Entries with any other discriminator are silently ignored.
        """
        self.stoi = stoi
        self.max_seq_len = max_seq_len
        self.samples = []

        for text, dtype in data:
            if dtype == "Q":
                # Queries are used as-is; they carry no NSP pair.
                self.samples.append((text, dtype, None, None))
                continue
            if dtype != "C":
                continue

            # Chunks are split into sentences on '.'; empty pieces are dropped.
            sentences = [piece.strip() for piece in text.strip().split('.') if piece.strip()]
            if len(sentences) < 2:
                # A lone sentence is paired with itself.
                sentences = sentences * 2
            count = len(sentences)

            # Positive pairs: each sentence followed by its successor.
            self.samples.extend(
                (sentences[pos], "C", sentences[pos + 1], 1)
                for pos in range(count - 1)
            )
            # Negative pairs: every ordered pair more than one position apart.
            self.samples.extend(
                (sentences[a], "C", sentences[b], 0)
                for a in range(count)
                for b in range(count)
                if abs(a - b) > 1
            )

    def __len__(self):
        return len(self.samples)

    def _tokenize_to_ids(self, text: str) -> List[int]:
        """Whitespace-tokenize ``text`` and map tokens to ids (UNK for unknowns)."""
        unk_id = self.stoi[UNK_TOKEN]
        return [self.stoi.get(word, unk_id) for word in text.strip().split()]

    def __getitem__(self, idx):
        """Encode one sample as unpadded tensors.

        Returns a dict with ``input_ids``, ``token_type_ids``, ``nsp_label``
        (all long tensors) and ``batch_type`` ("Q" or "C"). Query samples get
        the dummy NSP label -100.
        """
        first, kind, second, label = self.samples[idx]

        if kind == 'Q':
            # [CLS] query [SEP], truncated so the specials still fit.
            body = self._tokenize_to_ids(first)[:self.max_seq_len - 2]
            input_ids = [self.stoi[CLS_TOKEN]] + body + [self.stoi[SEP_TOKEN]]
            return {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "token_type_ids": torch.tensor([0] * len(input_ids), dtype=torch.long),
                "nsp_label": torch.tensor(-100, dtype=torch.long),
                "batch_type": "Q"
            }

        if kind == 'C':
            left = self._tokenize_to_ids(first)
            right = self._tokenize_to_ids(second)
            # Leave room for [CLS] and two [SEP]; trim the longer side first
            # (the second sentence on ties), one token at a time from the end.
            budget = self.max_seq_len - 3
            while len(left) + len(right) > budget:
                victim = left if len(left) > len(right) else right
                victim.pop()

            cls_id, sep_id = self.stoi[CLS_TOKEN], self.stoi[SEP_TOKEN]
            input_ids = [cls_id] + left + [sep_id] + right + [sep_id]
            # Segment 0 covers [CLS] A [SEP]; segment 1 covers B [SEP].
            token_type_ids = [0] * (len(left) + 2) + [1] * (len(right) + 1)
            return {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
                "nsp_label": torch.tensor(label, dtype=torch.long),
                "batch_type": "C"
            }

        # Any other kind yields None, exactly as an unmatched branch would.
        return None

def collate_fn(batch, pad_id):
    """Right-pad a list of dataset items into batch tensors.

    Args:
        batch: list of dicts with ``input_ids``, ``token_type_ids``,
            ``nsp_label`` (tensors) and ``batch_type`` (str).
        pad_id: id used to pad ``input_ids``; ``token_type_ids`` are padded
            with 0.

    Returns:
        Dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``
        (each ``(batch, max_len)``), ``nsp_labels`` (``(batch,)``) and
        ``batch_type`` (list of str). The attention mask is 1 wherever the
        padded ``input_ids`` differ from ``pad_id``.
    """
    sequences = [item["input_ids"] for item in batch]
    segments = [item["token_type_ids"] for item in batch]
    labels = torch.stack([item["nsp_label"] for item in batch]).long()
    kinds = [item["batch_type"] for item in batch]

    width = max(seq.size(0) for seq in sequences)

    id_rows, segment_rows, mask_rows = [], [], []
    for seq, seg in zip(sequences, segments):
        missing = width - seq.size(0)
        padded = F.pad(seq, (0, missing), value=pad_id)
        id_rows.append(padded)
        segment_rows.append(F.pad(seg, (0, missing), value=0))
        mask_rows.append((padded != pad_id).long())

    return {
        "input_ids": torch.stack(id_rows),
        "token_type_ids": torch.stack(segment_rows),
        "attention_mask": torch.stack(mask_rows),
        "nsp_labels": labels,
        "batch_type": kinds
    }

# -------------------------
# MLM Masking
# -------------------------
def create_mlm_labels_and_masked_input(input_ids, pad_id, mask_token_id, vocab_size, mask_prob=None):
    """Apply BERT-style random masking and build the MLM labels.

    Each maskable position is selected with probability ``mask_prob``.
    Padding and every id below ``len(SPECIAL_TOKENS)`` are never selected.
    Of the selected positions, roughly 80% become ``mask_token_id``, 10%
    become a random non-special token id, and 10% are left unchanged.

    Args:
        input_ids: long tensor of token ids (any device).
        pad_id: padding id.
        mask_token_id: id of the [MASK] token.
        vocab_size: exclusive upper bound for random replacement ids.
        mask_prob: selection probability. ``None`` (the default) means
            "use the current value of the module-global MLM_MASK_PROB",
            looked up at call time so that callers who temporarily rebind
            the global actually influence the masking.

    Returns:
        (input_ids_masked, mlm_labels): both shaped like ``input_ids`` and on
        the same device. ``mlm_labels`` is -100 everywhere except at the
        selected positions, where it holds the original token id.
    """
    if mask_prob is None:
        # Resolved here rather than in the signature: a default of
        # ``mask_prob=MLM_MASK_PROB`` would freeze the value at definition time.
        mask_prob = MLM_MASK_PROB

    device = input_ids.device
    num_special = len(SPECIAL_TOKENS)

    # Built on the input's device so the boolean-mask assignments below work
    # for GPU inputs too.
    prob_matrix = torch.full(input_ids.shape, float(mask_prob), device=device)
    prob_matrix[input_ids == pad_id] = 0.0
    prob_matrix[input_ids < num_special] = 0.0
    masked_positions = torch.bernoulli(prob_matrix).bool()

    mlm_labels = torch.full_like(input_ids, -100)
    mlm_labels[masked_positions] = input_ids[masked_positions]

    input_ids_masked = input_ids.clone()
    # One uniform draw per position decides the 80/10/10 split.
    rand_for_replace = torch.rand_like(input_ids, dtype=torch.float)
    mask_replace = masked_positions & (rand_for_replace < 0.8)
    random_replace = masked_positions & (rand_for_replace >= 0.8) & (rand_for_replace < 0.9)

    input_ids_masked[mask_replace] = mask_token_id
    if random_replace.any():
        count = int(random_replace.sum().item())
        rand_tokens = torch.randint(num_special, vocab_size, (count,), device=device)
        input_ids_masked[random_replace] = rand_tokens
    return input_ids_masked, mlm_labels

# -------------------------
# Mixture-of-Experts Module (unchanged logic)
# -------------------------
class MoE(nn.Module):
    """Mixture-of-Experts feed-forward block with noisy top-k routing.

    Every expert is a two-layer FFN (H -> ffn_dim -> H, GELU in between).
    A linear router scores the experts per token; the ``k`` highest-scoring
    experts are combined with softmax-normalised weights. All experts are
    evaluated on every token (simple, not efficient).

    Args:
        hidden_size: token width H.
        ffn_dim: inner width of each expert.
        num_experts: number of experts E.
        k: experts combined per token; must satisfy ``1 <= k <= num_experts``.
        noise_std: std of the Gaussian noise added to the router logits
            while the module is in training mode.

    Raises:
        ValueError: if ``k`` is outside ``[1, num_experts]``.
    """

    def __init__(self, hidden_size, ffn_dim, num_experts=5, k=2, noise_std=1.0):
        super().__init__()
        # Fail at construction time instead of deep inside torch.topk.
        if not 1 <= k <= num_experts:
            raise ValueError(
                f"k must be between 1 and num_experts ({num_experts}), got {k}"
            )
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.num_experts = num_experts
        self.k = k
        self.noise_std = noise_std

        # experts: each expert is a small Feed-Forward Network (H -> ffn_dim -> H)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, ffn_dim),
                nn.GELU(),
                nn.Linear(ffn_dim, hidden_size)
            ) for _ in range(num_experts)
        ])

        # router: maps hidden vector to expert logits
        self.router = nn.Linear(hidden_size, num_experts)

    def forward(self, x, mask=None):
        """
        x: (B, S, H)
        mask: accepted for interface compatibility; currently unused, so
              padded positions take part in routing and in the aux loss.
        returns: out (B, S, H), aux_loss (scalar)
        """
        B, S, _ = x.size()

        # ---- load-balancing loss from the noiseless router distribution ----
        logits = self.router(x)                   # (B, S, E)
        probs_all = F.softmax(logits, dim=-1)     # (B, S, E)
        importance = probs_all.sum(dim=(0, 1))    # (E,)
        total_tokens = float(max(B * S, 1))       # guard against an empty batch
        # E * sum_e(mean_prob_e ^ 2): equals 1 when the experts are balanced.
        aux_loss = self.num_experts * (importance / total_tokens).pow(2).sum()

        # ---- expert selection (noise only while training) ----
        if self.training:
            logits_noisy = logits + torch.randn_like(logits) * self.noise_std
        else:
            logits_noisy = logits
        topk_vals, topk_idx = torch.topk(logits_noisy, self.k, dim=-1)  # (B, S, k)
        topk_weights = F.softmax(topk_vals, dim=-1)                     # (B, S, k)

        # Dense evaluation of every expert on the full input.
        expert_stack = torch.stack([expert(x) for expert in self.experts], dim=2)  # (B, S, E, H)

        # Sparse gate: zero everywhere except at the selected experts.
        gating = x.new_zeros(B, S, self.num_experts).scatter(-1, topk_idx, topk_weights)

        # out[b, s, :] = sum_e gating[b, s, e] * expert_stack[b, s, e, :]
        out = torch.einsum('bse,bseh->bsh', gating, expert_stack)
        return out, aux_loss

# -------------------------
# Transformer encoder (unchanged logic)
# -------------------------
class TransformerEncoderLayer(nn.Module):
    """Post-LayerNorm encoder layer whose feed-forward part is a MoE block.

    Layout: self-attention -> dropout -> residual + LayerNorm, then
    MoE FFN -> dropout -> residual + LayerNorm.
    """

    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1, moe_experts=5, moe_k=2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        # The position-wise FFN is a mixture of experts.
        self.ffn_moe = MoE(hidden_size, ffn_dim, num_experts=moe_experts, k=moe_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the layer.

        Args:
            x: hidden states, batch-first.
            mask: attention mask; positions equal to 0 are excluded as
                attention keys.

        Returns:
            (hidden states, MoE auxiliary loss of this layer).
        """
        ignored_keys = mask == 0
        attended = self.self_attn(x, x, x, key_padding_mask=ignored_keys)[0]
        x = self.ln1(x + self.dropout(attended))

        moe_out, aux_loss = self.ffn_moe(x, mask)
        return self.ln2(x + self.dropout(moe_out)), aux_loss

class BertEncoderModel(nn.Module):
    """BERT-style encoder with an NSP head and a weight-tied MLM head.

    Input embeddings are the sum of token, learned position and segment
    embeddings, followed by LayerNorm and dropout. The MLM logits reuse the
    token-embedding matrix as output projection plus a separate bias.
    """

    def __init__(self, vocab_size, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, ffn_dim=FFN_DIM, max_position_embeddings=512, pad_token_id=0, embedding_weights=None, moe_experts=5, moe_k=2):
        super().__init__()
        self.pad_token_id = pad_token_id

        self.token_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        if embedding_weights is not None:
            # Start from the supplied (e.g. Word2Vec-derived) table.
            self.token_embeddings.weight.data.copy_(embedding_weights)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.segment_embeddings = nn.Embedding(2, hidden_size)
        self.emb_ln = nn.LayerNorm(hidden_size)
        self.emb_dropout = nn.Dropout(0.1)

        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                TransformerEncoderLayer(
                    hidden_size, num_heads, ffn_dim,
                    dropout=DROPOUT, moe_experts=moe_experts, moe_k=moe_k,
                )
            )
        self.layers = nn.ModuleList(encoder_layers)

        self.nsp_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 2),
        )
        self.mlm_bias = nn.Parameter(torch.zeros(vocab_size))

    def encode(self, ids, tt=None, mask=None):
        """Return (sequence output, summed MoE aux loss over all layers).

        ``tt`` defaults to all-zero segment ids; ``mask`` defaults to
        "every position whose id is not the pad id".
        """
        tt = torch.zeros_like(ids) if tt is None else tt
        mask = (ids != self.pad_token_id).long() if mask is None else mask

        positions = torch.arange(ids.size(1), device=ids.device).unsqueeze(0)
        embedded = (
            self.token_embeddings(ids)
            + self.position_embeddings(positions)
            + self.segment_embeddings(tt)
        )
        hidden = self.emb_dropout(self.emb_ln(embedded))

        aux_sum = 0.0
        for block in self.layers:
            hidden, layer_aux = block(hidden, mask)
            aux_sum = aux_sum + layer_aux
        return hidden, aux_sum

    def forward(self, ids, tt=None, mask=None):
        """Return (mlm_logits, nsp_logits, total_aux).

        NSP logits are computed from the first position of the sequence
        output; MLM logits project every position onto the vocabulary.
        """
        hidden, total_aux = self.encode(ids, tt, mask)
        nsp_logits = self.nsp_classifier(hidden[:, 0])
        mlm_logits = F.linear(hidden, self.token_embeddings.weight, self.mlm_bias)
        return mlm_logits, nsp_logits, total_aux

# -------------------------
# Training / evaluation helpers (refactors of your loop)
# -------------------------
def train_one_epoch(model: nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer,
                    mlm_loss_fct, nsp_loss_fct, pad_id: int, mask_id: int, vocab_size: int, device: torch.device,
                    aux_coeff: float):
    """Train ``model`` for one pass over ``dataloader``.

    Per batch the loss is ``mlm_loss + nsp_loss + aux_coeff * aux_loss``:

    * MLM loss over the randomly masked positions (labels of -100 are the
      unmasked positions). If no position was masked the MLM term is a
      graph-connected zero instead of the NaN a mean over nothing would give.
    * NSP loss over the "C" rows of the batch only. Mixed Q/C batches still
      contribute their chunk pairs; an all-"Q" batch contributes zero.
    * The model's summed MoE auxiliary loss.

    Args:
        model: returns ``(mlm_logits, nsp_logits, aux_loss)``.
        dataloader: yields batches shaped like ``collate_fn``'s output.
        optimizer: stepped once per batch.
        mlm_loss_fct, nsp_loss_fct: loss callables taking (logits, targets).
        pad_id, mask_id, vocab_size: forwarded to the masking helper.
        device: device the batch is moved to.
        aux_coeff: weight of the auxiliary loss.

    Returns:
        Mean total loss per batch (0.0 for an empty dataloader).
    """
    model.train()
    total_loss = 0.0
    total_steps = 0
    for batch in dataloader:
        tts = batch["token_type_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        nsp_labels = batch["nsp_labels"].to(device)
        btypes = batch["batch_type"]

        # Mask before the device transfer: the helper then sees the tensors
        # exactly as the loader produced them, and only the results move.
        ids_masked, mlm_labels = create_mlm_labels_and_masked_input(
            batch["input_ids"], pad_id, mask_id, vocab_size
        )
        ids_masked, mlm_labels = ids_masked.to(device), mlm_labels.to(device)

        optimizer.zero_grad()
        mlm_logits, nsp_logits, aux_loss = model(ids_masked, tts, mask)

        if (mlm_labels != -100).any():
            mlm_loss = mlm_loss_fct(mlm_logits.view(-1, mlm_logits.size(-1)), mlm_labels.view(-1))
        else:
            # Nothing was masked in this batch; keep the graph, add nothing.
            mlm_loss = mlm_logits.sum() * 0.0

        # NSP only makes sense for sentence pairs; select those rows.
        is_chunk = torch.tensor([bt == "C" for bt in btypes], dtype=torch.bool, device=device)
        if is_chunk.any():
            nsp_loss = nsp_loss_fct(nsp_logits[is_chunk], nsp_labels[is_chunk])
        else:
            nsp_loss = torch.tensor(0.0, device=device)

        loss = mlm_loss + nsp_loss + aux_coeff * aux_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_steps += 1

    return total_loss / max(1, total_steps)

def evaluate_model(model: nn.Module, dataloader: DataLoader, pad_id: int, mask_id: int, vocab_size: int, device: torch.device,
                   aux_coeff: float = 0.01):
    """Compute the mean total loss of ``model`` over ``dataloader``.

    Uses the same objective as training,
    ``mlm_loss + nsp_loss + aux_coeff * aux_loss``, under ``torch.no_grad()``
    with the model in eval mode. Masking is drawn randomly on every call, so
    the result is stochastic unless the caller seeds the RNG.

    * A batch without any masked position contributes an MLM loss of 0
      rather than NaN.
    * NSP loss is computed over the "C" rows of each batch (mixed Q/C
      batches included); an all-"Q" batch contributes 0.

    Args:
        model: returns ``(mlm_logits, nsp_logits, aux_loss)``.
        dataloader: yields batches shaped like ``collate_fn``'s output.
        pad_id, mask_id, vocab_size: forwarded to the masking helper.
        device: device the batch is moved to.
        aux_coeff: weight of the auxiliary loss (default 0.01).

    Returns:
        Mean total loss per batch (0.0 for an empty dataloader).
    """
    model.eval()
    total_loss = 0.0
    total_steps = 0
    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    nsp_loss_fct = nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch in dataloader:
            tts = batch["token_type_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            nsp_labels = batch["nsp_labels"].to(device)
            btypes = batch["batch_type"]

            # Mask before the device transfer: the helper then sees the
            # tensors exactly as the loader produced them.
            ids_masked, mlm_labels = create_mlm_labels_and_masked_input(
                batch["input_ids"], pad_id, mask_id, vocab_size
            )
            ids_masked, mlm_labels = ids_masked.to(device), mlm_labels.to(device)

            mlm_logits, nsp_logits, aux_loss = model(ids_masked, tts, mask)

            if (mlm_labels != -100).any():
                mlm_loss = mlm_loss_fct(mlm_logits.view(-1, mlm_logits.size(-1)), mlm_labels.view(-1))
            else:
                # Nothing was masked; a mean over zero targets would be NaN.
                mlm_loss = torch.tensor(0.0, device=device)

            # NSP only makes sense for sentence pairs; select those rows.
            is_chunk = torch.tensor([bt == "C" for bt in btypes], dtype=torch.bool, device=device)
            if is_chunk.any():
                nsp_loss = nsp_loss_fct(nsp_logits[is_chunk], nsp_labels[is_chunk])
            else:
                nsp_loss = torch.tensor(0.0, device=device)

            loss = mlm_loss + nsp_loss + aux_coeff * aux_loss
            total_loss += loss.item()
            total_steps += 1

    return total_loss / max(1, total_steps)

# -------------------------
# Helper: build data, vocab, embeddings, dataloaders given hyperparams
# -------------------------
def prepare_data_and_model(corpus: List[Tuple[str, str]],
                           stoi_override: Dict[str,int]=None,
                           hyperparams: Dict[str, Any]=None):
    """Build vocabulary, Word2Vec-initialised embeddings and the dataset.

    Despite the name, no model is constructed here.

    Args:
        corpus: list of ``(text, kind)`` tuples.
        stoi_override: accepted but not used by this function.
        hyperparams: optional overrides for ``word2vec_window``,
            ``word2vec_size``, ``word2vec_min_count`` and ``max_seq_len``;
            missing keys fall back to the module defaults.

    Returns:
        stoi, itos, vocab_size, emb (tensor), pad_id, mask_id, ds (dataset)
    """
    overrides = hyperparams or {}

    def _int_setting(key, fallback):
        return int(overrides.get(key, fallback))

    w2v_window = _int_setting("word2vec_window", WORD2VEC_WINDOW)
    w2v_size = _int_setting("word2vec_size", WORD2VEC_SIZE)
    w2v_min_count = _int_setting("word2vec_min_count", WORD2VEC_MIN_COUNT)
    seq_limit = _int_setting("max_seq_len", MAX_SEQ_LEN)

    # Vocabulary comes from the raw text of every corpus entry.
    raw_texts = [entry[0] for entry in corpus]
    stoi, itos = build_vocab(raw_texts, min_freq=VOCAB_MIN_FREQ)

    # Word2Vec vectors are fitted to the model width inside
    # build_embedding_matrix (truncate / zero-pad).
    word_vectors = train_word2vec(
        raw_texts,
        vector_size=w2v_size,
        window=w2v_window,
        min_count=w2v_min_count,
        epochs=5,
    )
    emb = build_embedding_matrix(word_vectors, itos, HIDDEN_SIZE)

    pad_id = stoi[PAD_TOKEN]
    mask_id = stoi[MASK_TOKEN]
    ds = BertPretrainingDataset(corpus, stoi, max_seq_len=seq_limit)
    return stoi, itos, len(itos), emb, pad_id, mask_id, ds

# -------------------------
# BO helpers: search space sampling, GP + EI
# -------------------------
def sample_random_hyperparams():
    """
    Draw one random configuration from the search space.

    Uses NumPy's global RNG; the draws happen in the order of the keys below:
      learning_rate    log-uniform in [1e-6, 5e-5]
      moe_experts      one of {3, 4, 5, 6}
      moe_k            one of {1, 2}
      ffn_dim          multiple of 256 in [1024, 4096]
      num_layers       one of {6, 8, 12}
      num_heads        one of {6, 8, 12}
      word2vec_window  one of {3, 5, 7}
      mlm_mask_prob    uniform in [0.10, 0.20]
    """
    rng = np.random

    log_low, log_high = np.log10(1e-6), np.log10(5e-5)
    learning_rate = float(10 ** rng.uniform(log_low, log_high))
    moe_experts = int(rng.choice([3, 4, 5, 6]))
    moe_k = int(rng.choice([1, 2]))
    ffn_dim = int(rng.choice(np.arange(1024, 4097, 256)))
    num_layers = int(rng.choice([6, 8, 12]))
    num_heads = int(rng.choice([6, 8, 12]))
    word2vec_window = int(rng.choice([3, 5, 7]))
    mlm_mask_prob = float(rng.uniform(0.10, 0.20))

    return {
        "learning_rate": learning_rate,
        "moe_experts": moe_experts,
        "moe_k": moe_k,
        "ffn_dim": ffn_dim,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "word2vec_window": word2vec_window,
        "mlm_mask_prob": mlm_mask_prob,
    }

def hyperparams_to_vector(hp: Dict[str,Any]) -> np.ndarray:
    """
    Encode a hyperparameter dict as the numeric feature vector fed to the GP.

    Fixed feature order:
    [log10(learning_rate), moe_experts, moe_k, ffn_dim/256, num_layers,
     num_heads, word2vec_window, mlm_mask_prob]

    Keys outside this list (e.g. batch_size) are ignored.
    """
    features = (
        np.log10(hp["learning_rate"]),
        float(hp["moe_experts"]),
        float(hp["moe_k"]),
        float(hp["ffn_dim"]) / 256.0,
        float(hp["num_layers"]),
        float(hp["num_heads"]),
        float(hp["word2vec_window"]),
        float(hp["mlm_mask_prob"]),
    )
    return np.array(features, dtype=float)

def expected_improvement(mu: np.ndarray, sigma: np.ndarray, best: float, xi: float = 0.01):
    """
    Expected Improvement for minimization.

    With f ~ N(mu, sigma^2) and improvement defined as ``best - f - xi``:
        EI = imp * Phi(Z) + sigma * phi(Z),  imp = best - mu - xi,  Z = imp / sigma

    Args:
        mu: predicted means; array-like or scalar.
        sigma: predicted standard deviations, same shape as ``mu``. Values
            below 1e-9 are clamped to 1e-9 to avoid dividing by zero, so a
            zero-uncertainty candidate gets EI ~= max(imp, 0).
        best: best (lowest) objective value observed so far.
        xi: exploration margin subtracted from the improvement.

    Returns:
        EI values with the broadcast shape of ``mu`` and ``sigma``.
    """
    mu = np.asarray(mu, dtype=float)
    # After the clamp sigma is never exactly zero, so no special case is needed.
    sigma = np.maximum(np.asarray(sigma, dtype=float), 1e-9)
    imp = best - mu - xi
    Z = imp / sigma
    return imp * norm.cdf(Z) + sigma * norm.pdf(Z)

# -------------------------
# Objective function for BO
# -------------------------
def objective_function(hp: Dict[str,Any],
                       corpus: List[Tuple[str,str]],
                       run_id: int,
                       verbose: bool = True) -> float:
    """
    Evaluate one hyperparameter configuration and return its validation loss.

    Steps:
      1. Build vocabulary, embeddings and dataset for ``hp``.
      2. Split roughly 80/20 into train/validation (at least one validation
         sample; with fewer than two samples both splits are the full set).
      3. Build the model from the architecture entries of ``hp``.
      4. Train for TRIAL_EPOCHS epochs with AdamW.
      5. Return the average validation loss (MLM + NSP + aux).

    While training and evaluating, the module-global MLM_MASK_PROB is rebound
    to ``hp["mlm_mask_prob"]`` and restored afterwards, even on error.
    NOTE(review): this only affects masking if the masking helper reads the
    global at call time — confirm.
    """
    stoi, itos, vocab_size, emb, pad_id, mask_id, ds = prepare_data_and_model(corpus, hyperparams=hp)

    # ---- train / validation split ----
    n_samples = len(ds)
    if n_samples >= 2:
        n_val = max(1, n_samples // 5)
        train_ds, val_ds = random_split(ds, [n_samples - n_val, n_val])
    else:
        # Degenerate corpus: train and validate on the same data.
        train_ds = val_ds = ds

    # ---- dataloaders ----
    batch_size = int(hp.get("batch_size", BATCH_SIZE))

    def _collate(items):
        return collate_fn(items, pad_id)

    dl_train = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=_collate)
    dl_val = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=_collate)

    # ---- model (hidden size stays fixed so the embedding table fits) ----
    model = BertEncoderModel(
        vocab_size=vocab_size,
        hidden_size=HIDDEN_SIZE,
        num_layers=int(hp["num_layers"]),
        num_heads=int(hp["num_heads"]),
        ffn_dim=int(hp["ffn_dim"]),
        max_position_embeddings=MAX_SEQ_LEN,
        pad_token_id=pad_id,
        embedding_weights=emb,
        moe_experts=int(hp["moe_experts"]),
        moe_k=int(hp["moe_k"])
    ).to(DEVICE)

    # ---- losses and optimizer ----
    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    nsp_loss_fct = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=float(hp["learning_rate"]))

    # ---- temporary override of the global masking probability ----
    global MLM_MASK_PROB
    saved_mask_prob = MLM_MASK_PROB
    MLM_MASK_PROB = float(hp.get("mlm_mask_prob", MLM_MASK_PROB))

    aux_coeff = 0.01  # weight of the MoE load-balancing loss during training

    try:
        for epoch in range(TRIAL_EPOCHS):
            train_loss = train_one_epoch(
                model, dl_train, optimizer, mlm_loss_fct, nsp_loss_fct,
                pad_id, mask_id, vocab_size, DEVICE, aux_coeff,
            )
            if verbose:
                print(f"[Run {run_id}] Epoch {epoch+1}/{TRIAL_EPOCHS} - train_loss: {train_loss:.6f}")

        val_loss = evaluate_model(model, dl_val, pad_id, mask_id, vocab_size, DEVICE)
        if verbose:
            print(f"[Run {run_id}] Validation loss: {val_loss:.6f}")
    finally:
        MLM_MASK_PROB = saved_mask_prob

    # Ask the CUDA caching allocator to release unused blocks.
    torch.cuda.empty_cache()

    return float(val_loss)

# -------------------------
# Main: Bayesian Optimization loop
# -------------------------
def run_bayesian_optimization(corpus: List[Tuple[str,str]]):
    """Run the full Bayesian-optimization study and train the final model.

    Phases:
      1. BO_INIT_POINTS random configurations are evaluated.
      2. For the remaining trials up to BO_ITERATIONS in total, a Gaussian
         Process is refit on all observations, 200 random candidates are
         scored with Expected Improvement, and the best one is evaluated.
      3. The lowest-loss configuration is written to ``best_config.json``.
      4. A model with that configuration is trained on the whole dataset for
         TRIAL_EPOCHS epochs; weights and vocabulary are saved under
         ``BO_LOG_DIR/best_model``.

    After every trial the complete history is rewritten to
    ``BO_LOG_DIR/bo_results.json``.

    Returns:
        The best trial record: ``{"iteration", "params", "loss"}``.
    """
    results_path = os.path.join(BO_LOG_DIR, "bo_results.json")

    bo_records = []
    observed_x = []
    observed_y = []

    def draw_config():
        # Random sample plus the settings that are held fixed during the search.
        cfg = sample_random_hyperparams()
        cfg["batch_size"] = BATCH_SIZE
        cfg["word2vec_size"] = WORD2VEC_SIZE
        cfg["word2vec_min_count"] = WORD2VEC_MIN_COUNT
        return cfg

    def log_trial(step, params, loss):
        # Store the observation and persist the whole history.
        observed_x.append(hyperparams_to_vector(params))
        observed_y.append(loss)
        bo_records.append({"iteration": step, "params": params, "loss": float(loss)})
        with open(results_path, "w") as f:
            json.dump(bo_records, f, indent=2)

    # ---- phase 1: random initial design ----
    print("[BO] Starting initial random evaluations...")
    for i in range(BO_INIT_POINTS):
        hp = draw_config()
        print(f"[BO] Init sample {i+1}/{BO_INIT_POINTS}: {hp}")
        loss_val = objective_function(hp, corpus, run_id=i+1, verbose=True)
        log_trial(i + 1, hp, loss_val)

    # ---- GP surrogate: scaled anisotropic Matern 5/2 plus white noise ----
    n_dims = observed_x[0].shape[0]
    kernel = (
        C(1.0, (1e-3, 1e3)) * Matern(length_scale=np.ones(n_dims), nu=2.5)
        + WhiteKernel(noise_level=1e-6)
    )
    gp = GaussianProcessRegressor(kernel=kernel, alpha=1e-6, normalize_y=True, n_restarts_optimizer=3, random_state=42)

    # ---- phase 2: EI-guided trials ----
    n_candidates = 200
    for it in range(BO_INIT_POINTS, BO_ITERATIONS):
        print(f"[BO] Iteration {it+1}/{BO_ITERATIONS}")

        y_arr = np.array(observed_y)
        gp.fit(np.vstack(observed_x), y_arr)

        # Random candidate pool; the acquisition is maximised by enumeration.
        candidates = []
        cand_rows = []
        for _ in range(n_candidates):
            cand = draw_config()
            candidates.append(cand)
            cand_rows.append(hyperparams_to_vector(cand))
        cand_vecs = np.vstack(cand_rows)

        mu, sigma = gp.predict(cand_vecs, return_std=True)
        eis = expected_improvement(mu, sigma, np.min(y_arr), xi=0.01)
        chosen_hp = candidates[int(np.argmax(eis))]
        print(f"[BO] Chosen hyperparams (EI max): {chosen_hp}")

        loss_val = objective_function(chosen_hp, corpus, run_id=it+1, verbose=True)
        log_trial(it + 1, chosen_hp, loss_val)

    # ---- phase 3: pick and persist the best configuration ----
    best_record = bo_records[int(np.argmin([rec["loss"] for rec in bo_records]))]
    best_hp = best_record["params"]
    best_loss = best_record["loss"]
    print(f"[BO] Best config found: {best_hp} with loss {best_loss:.6f}")

    with open(os.path.join(BO_LOG_DIR, "best_config.json"), "w") as f:
        json.dump({"best_params": best_hp, "best_loss": best_loss}, f, indent=2)

    # ---- phase 4: final training on the full dataset ----
    print("[BO] Training final model on full dataset with best hyperparameters...")
    stoi, itos, vocab_size, emb, pad_id, mask_id, ds = prepare_data_and_model(corpus, hyperparams=best_hp)

    def _collate(items):
        return collate_fn(items, pad_id)

    dl_full = DataLoader(ds, batch_size=int(best_hp.get("batch_size", BATCH_SIZE)), shuffle=True, collate_fn=_collate)

    final_model = BertEncoderModel(
        vocab_size=vocab_size,
        hidden_size=HIDDEN_SIZE,
        num_layers=int(best_hp["num_layers"]),
        num_heads=int(best_hp["num_heads"]),
        ffn_dim=int(best_hp["ffn_dim"]),
        max_position_embeddings=MAX_SEQ_LEN,
        pad_token_id=pad_id,
        embedding_weights=emb,
        moe_experts=int(best_hp["moe_experts"]),
        moe_k=int(best_hp["moe_k"])
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(final_model.parameters(), lr=float(best_hp["learning_rate"]))
    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    nsp_loss_fct = nn.CrossEntropyLoss()

    # Rebind the global masking probability for the final run; restored below.
    global MLM_MASK_PROB
    saved_mask_prob = MLM_MASK_PROB
    MLM_MASK_PROB = float(best_hp.get("mlm_mask_prob", MLM_MASK_PROB))
    aux_coeff = 0.01

    try:
        # Same epoch budget as a trial; raise TRIAL_EPOCHS for a longer final run.
        for epoch in range(TRIAL_EPOCHS):
            t_loss = train_one_epoch(
                final_model, dl_full, optimizer, mlm_loss_fct, nsp_loss_fct,
                pad_id, mask_id, vocab_size, DEVICE, aux_coeff,
            )
            print(f"[Final Train] Epoch {epoch+1}/{TRIAL_EPOCHS} - loss {t_loss:.6f}")
    finally:
        MLM_MASK_PROB = saved_mask_prob

    # ---- persist weights and vocabulary ----
    save_dir = os.path.join(BO_LOG_DIR, "best_model")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(final_model.state_dict(), os.path.join(save_dir, "bert_encoder.pt"))
    with open(os.path.join(save_dir, "vocab.json"), "w") as f:
        json.dump({"stoi": stoi, "itos": itos}, f)
    print(f"[BO] Best model and vocab saved to {save_dir}")

    return best_record

# -------------------------
# Example corpus (keeps your original small corpus)
# -------------------------
def get_default_corpus():
    """Return the small built-in demo corpus as ``(text, kind)`` tuples.

    ``kind`` is "C" for multi-sentence chunks and "Q" for single queries.
    A fresh list is returned on every call.
    """
    chunk_fox = "the quick brown fox jumps over the lazy dog. the dog did not mind."
    query_ml = "i love machine learning and transformers."
    chunk_dl = "deep learning enables summarization and translation. it is powerful."
    query_food = "best restaurants near me"

    corpus = []
    corpus.append((chunk_fox, "C"))
    corpus.append((query_ml, "Q"))
    corpus.append((chunk_dl, "C"))
    corpus.append((query_food, "Q"))
    return corpus

# -------------------------
# Entrypoint
# -------------------------
def main():
    """Run Bayesian optimization on the default corpus and print the best trial."""
    default_corpus = get_default_corpus()
    print("[MAIN] Starting Bayesian Optimization over hyperparameters...")
    best_record = run_bayesian_optimization(default_corpus)
    print("[MAIN] BO Completed. Best record:")
    report = json.dumps(best_record, indent=2)
    print(report)


# Only run the study when executed as a script, not on import.
if __name__ == "__main__":
    main()