# In [None]:
# bert_encoder_from_scratch_with_pooling_multitype_allpairs_bo_cv5.py
# Full script:
# - MLM + NSP (unchanged logic)
# - MoE Top-K routing (unchanged)
# - Bayesian Optimization (Gaussian Process + EI), 3 epochs/fold, 5-fold CV per BO eval, 10 iterations
# - Train/val/test split done once, test never used for BO
# - Final retrain on train+val with best hp, final evaluation on test set
# - Logs saved to bo_logs/

import random
import math
import os
import json
from typing import List, Tuple, Dict, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from gensim.models import Word2Vec
from datetime import datetime
from tqdm.auto import tqdm
import chromadb

# for Gaussian Process BO
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel, ConstantKernel as C
from scipy.stats import norm
from sklearn.model_selection import KFold

# -------------------------
# Config (DEFAULTS)
# -------------------------
# Module-level defaults. The BO objective overrides num_layers, num_heads,
# ffn_dim, learning_rate and batch_size per trial from the sampled
# hyperparameter dict; the values below are the fallbacks.
VOCAB_MIN_FREQ = 1
# Maximum tokens per encoded example (including [CLS]/[SEP]); also passed as
# the position-embedding table size when trial models are built.
MAX_SEQ_LEN = 1024
HIDDEN_SIZE = 768
NUM_LAYERS = 12
NUM_HEADS = 12
FFN_DIM = 3072
DROPOUT = 0.1
# Word2Vec vectors default to the encoder width so they can be copied
# directly into the token-embedding matrix.
WORD2VEC_SIZE = HIDDEN_SIZE
WORD2VEC_WINDOW = 5
WORD2VEC_MIN_COUNT = 1
MLM_MASK_PROB = 0.15
BATCH_SIZE = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-5

# -------------------------
# Special tokens
# -------------------------
PAD_TOKEN = "[PAD]"
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"
MASK_TOKEN = "[MASK]"
UNK_TOKEN = "[UNK]"
# Order matters: build_vocab assigns ids in this order (so [PAD] is id 0), and
# the MLM masking code treats every id below len(SPECIAL_TOKENS) as unmaskable.
SPECIAL_TOKENS = [PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN]

# -------------------------
# BO Settings
# -------------------------
BO_ITERATIONS = 10
BO_INIT_POINTS = 3   # random initial points
TRIAL_EPOCHS = 1    # training epochs per CV fold (NOTE(review): file header says 3 epochs/fold, value here is 1 - confirm which is intended)
K_FOLDS = 5          # number of cross-validation folds per BO evaluation
BO_LOG_DIR = "bo_logs"

# Import-time side effect: the log directories are created as soon as the
# module is loaded.
os.makedirs(BO_LOG_DIR, exist_ok=True)
os.makedirs(os.path.join(BO_LOG_DIR, "best_model"), exist_ok=True)

# -------------------------
# Utility: Vocab builder
# -------------------------
def build_vocab(sentences: List[str], min_freq: int = VOCAB_MIN_FREQ):
    """Build token<->id mappings from whitespace-tokenized sentences.

    The special tokens always occupy the first ids, in SPECIAL_TOKENS order.
    Remaining tokens follow in order of first appearance, provided they occur
    at least ``min_freq`` times.

    Returns:
        (stoi, itos): a dict from token to id and the list from id to token.
    """
    from collections import Counter

    frequencies = Counter(
        word for sentence in sentences for word in sentence.strip().split()
    )

    itos = list(SPECIAL_TOKENS)
    stoi = {tok: idx for idx, tok in enumerate(itos)}

    for word, freq in frequencies.items():
        # Skip rare tokens and anything already registered (e.g. a literal
        # "[PAD]" that appears in the text).
        if freq < min_freq or word in stoi:
            continue
        stoi[word] = len(itos)
        itos.append(word)
    return stoi, itos

# -------------------------
# Train or load Word2Vec
# -------------------------
def train_word2vec(sentences: List[str], vector_size=WORD2VEC_SIZE, window=WORD2VEC_WINDOW, min_count=WORD2VEC_MIN_COUNT, epochs=5):
    """Train a CBOW (sg=0) Word2Vec model on whitespace-tokenized sentences.

    Returns the fitted gensim ``Word2Vec`` model.
    """
    token_lists = [text.strip().split() for text in sentences]
    return Word2Vec(
        sentences=token_lists,
        vector_size=vector_size,
        window=window,
        min_count=min_count,
        epochs=epochs,
        sg=0,
    )

def build_embedding_matrix(w2v: Word2Vec, itos: List[str], hidden_size: int):
    """Create the initial token-embedding table from Word2Vec vectors.

    Every row starts as N(0, 0.02) noise. Rows whose token is known to ``w2v``
    are overwritten with the pretrained vector, truncated or zero-padded to
    ``hidden_size``. The [PAD] row is zeroed last.

    Returns:
        A float32 tensor of shape (len(itos), hidden_size).
    """
    table = np.random.normal(scale=0.02, size=(len(itos), hidden_size)).astype(np.float32)
    keyed_vectors = w2v.wv

    for row, token in enumerate(itos):
        if token not in keyed_vectors:
            continue
        pretrained = keyed_vectors[token]
        width = pretrained.shape[0]
        if width > hidden_size:
            pretrained = pretrained[:hidden_size]
        elif width < hidden_size:
            pretrained = np.pad(pretrained, (0, hidden_size - width))
        table[row] = pretrained

    # Padding must stay a zero vector regardless of what Word2Vec learned.
    table[itos.index(PAD_TOKEN)] = np.zeros(hidden_size, dtype=np.float32)
    return torch.tensor(table)

# -------------------------
# Dataset (supports queries and chunks)
# -------------------------
class BertPretrainingDataset(Dataset):
    """Pre-training samples: single queries (MLM only) and chunk sentence pairs (MLM + NSP).

    Each stored sample is a tuple ``(sent_a, dtype, sent_b, nsp_label)``:
      * ``dtype == "Q"``: ``sent_b`` and ``nsp_label`` are ``None``.
      * ``dtype == "C"``: ``sent_b`` is the second sentence and ``nsp_label``
        is 1 for consecutive sentences, 0 for non-consecutive ones.

    Entries whose discriminator is neither "Q" nor "C" are ignored.
    """

    def __init__(self, data: List[Tuple[str, str]], stoi: dict, max_seq_len=MAX_SEQ_LEN):
        """
        data: list of tuples [(text, discriminator)], where discriminator ∈ {'Q', 'C'}
        stoi: token -> id mapping; must contain the special tokens
        max_seq_len: maximum encoded length including [CLS]/[SEP]
        """
        self.stoi = stoi
        self.max_seq_len = max_seq_len
        self.samples = []

        for text, dtype in data:
            if dtype == "Q":
                # Single-sentence query (MLM only)
                self.samples.append((text, dtype, None, None))
            elif dtype == "C":
                # Split chunk into sentences on '.'
                sents = [s.strip() for s in text.strip().split('.') if s.strip()]
                if len(sents) < 2:
                    sents = sents + sents  # duplicate if only one sentence
                n_sents = len(sents)

                # Debug: print sentence count for the first few chunks.
                # Exact pair count: (n-1) positives + (n-1)(n-2) ordered
                # negatives with |i-j| > 1, i.e. (n-1)^2 in total.
                if len(self.samples) < 10:
                    n_pairs = (n_sents - 1) ** 2 if n_sents else 0
                    print(f"[DEBUG] Chunk has {n_sents} sentences, will generate ~{n_pairs} pairs")

                # Positive pairs: consecutive sentences
                for i in range(n_sents - 1):
                    self.samples.append((sents[i], "C", sents[i + 1], 1))
                # Negative pairs: every ordered pair that is neither identical
                # nor adjacent (both (i, j) and (j, i) are emitted)
                for i in range(n_sents):
                    for j in range(n_sents):
                        if abs(i - j) > 1:  # skip identical and consecutive
                            self.samples.append((sents[i], "C", sents[j], 0))

    def __len__(self):
        return len(self.samples)

    def _tokenize_to_ids(self, text: str) -> List[int]:
        """Whitespace-tokenize ``text`` and map tokens to ids ([UNK] if unknown)."""
        unk_id = self.stoi[UNK_TOKEN]
        return [self.stoi.get(t, unk_id) for t in text.strip().split()]

    def __getitem__(self, idx):
        """Encode sample ``idx`` into unpadded id tensors.

        Returns a dict with ``input_ids``, ``token_type_ids``, ``nsp_label``
        (-100 for queries) and ``batch_type`` ("Q" or "C").
        """
        sent_a, dtype, sent_b, nsp_label = self.samples[idx]

        # -------------------------------
        # Case 1: Query (MLM only)
        # -------------------------------
        if dtype == 'Q':
            ids = self._tokenize_to_ids(sent_a)
            ids = ids[:self.max_seq_len - 2]  # leave room for [CLS] and [SEP]
            input_ids = [self.stoi[CLS_TOKEN]] + ids + [self.stoi[SEP_TOKEN]]
            token_type_ids = [0] * len(input_ids)
            nsp_label = -100  # dummy: queries carry no NSP target
            return {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
                "nsp_label": torch.tensor(nsp_label, dtype=torch.long),
                "batch_type": "Q"
            }

        # -------------------------------
        # Case 2: Chunk (MLM + NSP)
        # -------------------------------
        elif dtype == 'C':
            ids_a = self._tokenize_to_ids(sent_a)
            ids_b = self._tokenize_to_ids(sent_b)
            # Trim the longer side one token at a time until the pair fits
            # together with [CLS] and the two [SEP] tokens.
            while len(ids_a) + len(ids_b) > self.max_seq_len - 3:
                if len(ids_a) > len(ids_b):
                    ids_a.pop()
                else:
                    ids_b.pop()
            input_ids = [self.stoi[CLS_TOKEN]] + ids_a + [self.stoi[SEP_TOKEN]] + ids_b + [self.stoi[SEP_TOKEN]]
            # Segment 0 covers [CLS] + A + [SEP]; segment 1 covers B + [SEP].
            token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
            return {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
                "nsp_label": torch.tensor(nsp_label, dtype=torch.long),
                "batch_type": "C"
            }

def collate_fn(batch, pad_id):
    """Right-pad a list of dataset items to a common length and stack them.

    Args:
        batch: list of dicts with ``input_ids``, ``token_type_ids``,
            ``nsp_label`` and ``batch_type`` entries.
        pad_id: id used to pad ``input_ids``; positions equal to it are marked
            0 in the attention mask.

    Returns:
        Dict with stacked ``input_ids``, ``token_type_ids``,
        ``attention_mask`` and ``nsp_labels`` tensors plus the list of
        per-item ``batch_type`` strings.
    """
    sequences = [item["input_ids"] for item in batch]
    segments = [item["token_type_ids"] for item in batch]
    labels = torch.stack([item["nsp_label"] for item in batch]).long()
    kinds = [item["batch_type"] for item in batch]

    target_len = max(seq.size(0) for seq in sequences)

    id_rows, segment_rows, mask_rows = [], [], []
    for seq, seg in zip(sequences, segments):
        fill = (0, target_len - seq.size(0))
        padded = F.pad(seq, fill, value=pad_id)
        id_rows.append(padded)
        segment_rows.append(F.pad(seg, fill, value=0))
        mask_rows.append((padded != pad_id).long())

    return {
        "input_ids": torch.stack(id_rows),
        "token_type_ids": torch.stack(segment_rows),
        "attention_mask": torch.stack(mask_rows),
        "nsp_labels": labels,
        "batch_type": kinds,
    }

# -------------------------
# MLM Masking
# -------------------------
def create_mlm_labels_and_masked_input(input_ids, pad_id, mask_token_id, vocab_size, device=None, mask_prob=None):
    """Apply BERT-style MLM corruption to a batch of token ids.

    Args:
        input_ids: integer tensor of shape (batch, seq_len).
        pad_id: padding id; never masked.
        mask_token_id: id written at positions replaced by [MASK].
        vocab_size: exclusive upper bound for random replacement ids.
        device: device for the sampling tensors; defaults to
            ``input_ids.device``.
        mask_prob: per-token masking probability. ``None`` (default) means
            "use the module-level MLM_MASK_PROB as it is *now*". The value is
            looked up at call time rather than frozen in the signature, so
            code that temporarily overrides the global actually changes the
            masking rate.

    Returns:
        (input_ids_masked, mlm_labels): the corrupted ids and a label tensor
        holding the original id at selected positions and -100 elsewhere.
        Of the selected positions ~80% become [MASK], ~10% a random
        non-special token, ~10% stay unchanged.
    """
    if device is None:
        device = input_ids.device
    if mask_prob is None:
        mask_prob = MLM_MASK_PROB
    batch_size, seq_len = input_ids.shape
    mlm_labels = torch.full_like(input_ids, -100)

    # Special tokens occupy ids [0, len(SPECIAL_TOKENS)); they and padding are
    # never selected for masking.
    special_upper = len(SPECIAL_TOKENS)
    maskable_positions = (input_ids >= special_upper) & (input_ids != pad_id)
    if maskable_positions.sum() == 0:
        # No maskable tokens, return unmasked input
        return input_ids.clone(), mlm_labels

    prob_matrix = torch.full((batch_size, seq_len), mask_prob, device=device)
    prob_matrix[~maskable_positions] = 0.0
    masked_positions = torch.bernoulli(prob_matrix).bool()

    # If the Bernoulli draw selected nothing in the whole batch, force one
    # random maskable position in every sequence that has one, so the MLM
    # loss is not computed over an empty target set.
    if not masked_positions.any():
        for b in range(batch_size):
            maskable_in_seq = maskable_positions[b].nonzero(as_tuple=True)[0]
            if len(maskable_in_seq) > 0:
                chosen_idx = maskable_in_seq[torch.randint(len(maskable_in_seq), (1,), device=device)]
                masked_positions[b, chosen_idx] = True

    mlm_labels[masked_positions] = input_ids[masked_positions]
    input_ids_masked = input_ids.clone()
    # One uniform draw per position decides the replacement strategy.
    rand_for_replace = torch.rand_like(input_ids, dtype=torch.float)
    mask_replace = masked_positions & (rand_for_replace < 0.8)
    random_replace = masked_positions & (rand_for_replace >= 0.8) & (rand_for_replace < 0.9)
    input_ids_masked[mask_replace] = mask_token_id
    if random_replace.any():
        count = int(random_replace.sum().item())
        rand_tokens = torch.randint(special_upper, vocab_size, (count,), device=device)
        input_ids_masked[random_replace] = rand_tokens
    return input_ids_masked, mlm_labels

# -------------------------
# Mixture-of-Experts Module (unchanged logic)
# -------------------------
class MoE(nn.Module):
    """Mixture-of-Experts feed-forward block with noisy top-k routing.

    Each expert is a feed-forward network (H -> ffn_dim -> GELU -> H). A
    linear router scores the experts per token; the ``k`` best-scoring
    experts are combined with softmax-normalized weights.

    Args:
        hidden_size: token representation width H.
        ffn_dim: inner width of every expert.
        num_experts: number of experts E.
        k: experts mixed per token; must satisfy 1 <= k <= num_experts.
        noise_std: std of the Gaussian noise added to router logits for
            expert selection while in training mode.

    Raises:
        ValueError: if ``k`` is outside [1, num_experts] (``torch.topk`` would
            otherwise fail later, inside ``forward``).
    """

    def __init__(self, hidden_size, ffn_dim, num_experts=5, k=2, noise_std=1.0):
        super().__init__()
        if not 1 <= k <= num_experts:
            raise ValueError(
                f"k must be in [1, num_experts]; got k={k}, num_experts={num_experts}"
            )
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.num_experts = num_experts
        self.k = k
        self.noise_std = noise_std

        # experts: each expert is a small Feed-Forward Network (H -> ffn_dim -> H)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, ffn_dim),
                nn.GELU(),
                nn.Linear(ffn_dim, hidden_size)
            ) for _ in range(num_experts)
        ])

        # router: maps hidden vector to expert logits
        self.router = nn.Linear(hidden_size, num_experts)

    def forward(self, x, mask=None):
        """
        x: (B, S, H)
        mask: accepted for interface compatibility but currently unused
            (NOTE(review): padded positions therefore also count towards the
            load-balancing statistics - confirm this is intended).
        returns: out (B, S, H), aux_loss (scalar)
        """
        B, S, H = x.size()

        # ---- load-balancing loss from noiseless router probabilities ----
        logits = self.router(x)                 # (B, S, E)
        probs_all = F.softmax(logits, dim=-1)   # (B, S, E)
        importance = probs_all.sum(dim=(0, 1))  # (E,)
        total_tokens = float(B * S)
        # E * sum_e(mean_prob_e ** 2): equals 1 when every expert receives the
        # same average probability and grows as routing becomes unbalanced.
        aux_loss = self.num_experts * (importance / total_tokens).pow(2).sum()

        # ---- expert selection (noise only while training) ----
        if self.training:
            logits_noisy = logits + torch.randn_like(logits) * self.noise_std
        else:
            logits_noisy = logits
        topk_vals, topk_idx = torch.topk(logits_noisy, self.k, dim=-1)  # (B, S, k)
        topk_weights = F.softmax(topk_vals, dim=-1)                     # (B, S, k)

        # Every expert runs on the full input (inefficient but simple).
        expert_stack = torch.stack([expert(x) for expert in self.experts], dim=2)  # (B, S, E, H)

        # Dense gate: zero everywhere except the selected experts, which hold
        # their normalized weights.
        gating = torch.zeros(B, S, self.num_experts, device=x.device, dtype=x.dtype)
        gating = gating.scatter(-1, topk_idx, topk_weights)  # (B, S, E)

        # out[b, s, :] = sum_e gating[b, s, e] * expert_stack[b, s, e, :]
        out = torch.einsum('bse,bseh->bsh', gating, expert_stack)  # (B, S, H)

        return out, aux_loss

# -------------------------
# Transformer encoder (unchanged logic)
# -------------------------
class TransformerEncoderLayer(nn.Module):
    """Post-LayerNorm encoder layer: self-attention followed by a MoE feed-forward block.

    ``forward`` returns the updated hidden states together with the MoE
    load-balancing loss of this layer.
    """

    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1, moe_experts=5, moe_k=2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        # The dense FFN of a standard layer is replaced by a MoE module.
        self.ffn_moe = MoE(hidden_size, ffn_dim, num_experts=moe_experts, k=moe_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Attention must ignore positions where the attention mask is 0.
        padded = mask == 0
        attended = self.self_attn(x, x, x, key_padding_mask=padded)[0]
        hidden = self.ln1(x + self.dropout(attended))

        # MoE feed-forward with residual connection.
        moe_out, balance_loss = self.ffn_moe(hidden, mask)
        hidden = self.ln2(hidden + self.dropout(moe_out))
        return hidden, balance_loss

class BertEncoderModel(nn.Module):
    """BERT-style encoder with MoE layers, an NSP head and a weight-tied MLM head.

    Args:
        vocab_size: number of token ids.
        hidden_size, num_layers, num_heads, ffn_dim: encoder dimensions.
        max_position_embeddings: size of the learned position table; inputs
            longer than this are rejected by ``encode``.
        pad_token_id: padding id (its embedding row is ``padding_idx``).
        embedding_weights: optional (vocab_size, hidden_size) tensor copied
            into the token-embedding table.
        moe_experts, moe_k: MoE configuration forwarded to every layer.
    """

    def __init__(self, vocab_size, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, ffn_dim=FFN_DIM, max_position_embeddings=512, pad_token_id=0, embedding_weights=None, moe_experts=5, moe_k=2):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        if embedding_weights is not None:
            self.token_embeddings.weight.data.copy_(embedding_weights)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.segment_embeddings = nn.Embedding(2, hidden_size)
        self.emb_ln = nn.LayerNorm(hidden_size)
        # Same rate as the encoder layers (DROPOUT) instead of a separate literal.
        self.emb_dropout = nn.Dropout(DROPOUT)
        self.layers = nn.ModuleList([TransformerEncoderLayer(hidden_size, num_heads, ffn_dim, dropout=DROPOUT, moe_experts=moe_experts, moe_k=moe_k) for _ in range(num_layers)])
        self.nsp_classifier = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh(), nn.Linear(hidden_size, 2))
        # Output bias of the MLM head; its weight matrix is the token-embedding table.
        self.mlm_bias = nn.Parameter(torch.zeros(vocab_size))

    def encode(self, ids, tt=None, mask=None):
        """Run the encoder stack.

        Args:
            ids: (B, S) token ids.
            tt: (B, S) segment ids; defaults to all zeros.
            mask: (B, S) attention mask (1 = real token); defaults to
                ``ids != pad_token_id``.

        Returns:
            (hidden_states, total_aux): (B, S, H) outputs and the sum of the
            per-layer MoE auxiliary losses.

        Raises:
            ValueError: if S exceeds the position-embedding table. Checked up
                front because an out-of-range embedding lookup otherwise
                fails with an opaque index error.
        """
        seq_len = ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings={max_positions}"
            )
        if tt is None:
            tt = torch.zeros_like(ids)
        if mask is None:
            mask = (ids != self.pad_token_id).long()
        pos = torch.arange(seq_len, device=ids.device).unsqueeze(0)
        x = self.token_embeddings(ids) + self.position_embeddings(pos) + self.segment_embeddings(tt)
        x = self.emb_dropout(self.emb_ln(x))
        total_aux = 0.0
        for layer in self.layers:
            x, aux = layer(x, mask)
            total_aux = total_aux + aux
        return x, total_aux

    def forward(self, ids, tt=None, mask=None):
        """Return ``(mlm_logits, nsp_logits, total_aux)`` for a batch.

        ``nsp_logits`` are computed from the first position ([CLS]);
        ``mlm_logits`` project every position onto the token-embedding matrix.
        """
        seq_out, total_aux = self.encode(ids, tt, mask)
        pooled = seq_out[:, 0]
        nsp_logits = self.nsp_classifier(pooled)
        mlm_logits = F.linear(seq_out, self.token_embeddings.weight, self.mlm_bias)
        return mlm_logits, nsp_logits, total_aux

# -------------------------
# Training / evaluation helpers (refactors of your loop)
# -------------------------
def train_one_epoch(model: nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer,
                    mlm_loss_fct, nsp_loss_fct, pad_id: int, mask_id: int, vocab_size: int, device: torch.device,
                    aux_coeff: float):
    """Train ``model`` for one pass over ``dataloader``.

    The per-batch loss is ``mlm + nsp + aux_coeff * aux``. The NSP term is
    computed over the chunk ("C") samples of the batch only, so a batch that
    mixes queries and chunks still contributes NSP supervision for its chunk
    rows. Batches without any chunk sample get an NSP loss of 0.

    Batches that produce a NaN/Inf loss are skipped (no optimizer step) and
    do not count towards the average.

    Returns:
        Mean loss over the batches that were actually stepped, or
        ``float('inf')`` if none were.
    """
    model.train()
    total_loss = 0.0
    total_steps = 0
    pbar = tqdm(dataloader, desc="Training", leave=False)
    for batch in pbar:
        ids = batch["input_ids"].to(device)
        tts = batch["token_type_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        nsp_labels = batch["nsp_labels"].to(device)
        btypes = batch["batch_type"]
        ids_masked, mlm_labels = create_mlm_labels_and_masked_input(ids, pad_id, mask_id, vocab_size, device=device)

        optimizer.zero_grad()
        mlm_logits, nsp_logits, aux_loss = model(ids_masked, tts, mask)
        mlm_loss = mlm_loss_fct(mlm_logits.view(-1, mlm_logits.size(-1)), mlm_labels.view(-1))

        # Only chunk samples carry a real NSP label; select their rows instead
        # of discarding the whole batch when a single query is present.
        is_chunk = torch.tensor([bt == "C" for bt in btypes], dtype=torch.bool, device=device)
        if is_chunk.any():
            nsp_loss = nsp_loss_fct(nsp_logits.view(-1, 2)[is_chunk], nsp_labels.view(-1)[is_chunk])
        else:
            nsp_loss = torch.tensor(0.0, device=device)

        loss = mlm_loss + nsp_loss + aux_coeff * aux_loss

        # Check for NaN loss before backprop
        if torch.isnan(loss) or torch.isinf(loss):
            print("[WARNING] NaN/Inf loss detected, skipping batch")
            continue

        loss.backward()

        # Gradient clipping to prevent explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item()
        total_steps += 1
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    if total_steps == 0:
        return float('inf')  # Return inf instead of NaN for empty dataloader
    avg_loss = total_loss / total_steps
    return avg_loss

def evaluate_model(model: nn.Module, dataloader: DataLoader, pad_id: int, mask_id: int, vocab_size: int, device: torch.device):
    """
    Evaluates total loss (MLM + NSP + aux) averaged across dataloader.

    MLM targets are drawn freshly at random for every batch, so repeated
    evaluations are only identical under a fixed RNG state. The NSP term is
    computed over the chunk ("C") samples of each batch; batches that mix
    queries and chunks therefore still contribute NSP loss for their chunk
    rows, and batches without chunk samples contribute 0.

    Returns avg_loss (float), or ``float('inf')`` for an empty dataloader.
    """
    model.eval()
    total_loss = 0.0
    total_steps = 0
    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    nsp_loss_fct = nn.CrossEntropyLoss()
    aux_coeff = 0.01
    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Evaluating", leave=False)
        for batch in pbar:
            ids = batch["input_ids"].to(device)
            tts = batch["token_type_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            nsp_labels = batch["nsp_labels"].to(device)
            btypes = batch["batch_type"]
            ids_masked, mlm_labels = create_mlm_labels_and_masked_input(ids, pad_id, mask_id, vocab_size, device=device)

            mlm_logits, nsp_logits, aux_loss = model(ids_masked, tts, mask)
            mlm_loss = mlm_loss_fct(mlm_logits.view(-1, mlm_logits.size(-1)), mlm_labels.view(-1))

            # Select the chunk rows rather than requiring an all-chunk batch.
            is_chunk = torch.tensor([bt == "C" for bt in btypes], dtype=torch.bool, device=device)
            if is_chunk.any():
                nsp_loss = nsp_loss_fct(nsp_logits.view(-1, 2)[is_chunk], nsp_labels.view(-1)[is_chunk])
            else:
                nsp_loss = torch.tensor(0.0, device=device)

            loss = mlm_loss + nsp_loss + aux_coeff * aux_loss
            total_loss += loss.item()
            total_steps += 1
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    if total_steps == 0:
        return float('inf')  # Return inf instead of NaN for empty dataloader
    avg_loss = total_loss / total_steps
    return avg_loss

def compute_metrics(model: nn.Module, dataloader: DataLoader, pad_id: int, mask_id: int, vocab_size: int, device: torch.device):
    """
    Compute MLM accuracy, NSP accuracy and average total loss on dataloader.

    NSP loss and accuracy are computed over the chunk ("C") samples of each
    batch, so chunk rows inside mixed query/chunk batches are counted too.
    MLM accuracy is measured on the randomly selected masked positions.

    Returns: dict with keys: avg_loss, mlm_acc, nsp_acc
    (accuracies are 0.0 when there was nothing to score).
    """
    model.eval()
    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    nsp_loss_fct = nn.CrossEntropyLoss()
    aux_coeff = 0.01

    total_loss = 0.0
    total_steps = 0

    total_mlm_correct = 0
    total_mlm_count = 0
    total_nsp_correct = 0
    total_nsp_count = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Computing Metrics", leave=False)
        for batch in pbar:
            ids = batch["input_ids"].to(device)
            tts = batch["token_type_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            nsp_labels = batch["nsp_labels"].to(device)
            btypes = batch["batch_type"]
            ids_masked, mlm_labels = create_mlm_labels_and_masked_input(ids, pad_id, mask_id, vocab_size, device=device)

            mlm_logits, nsp_logits, aux_loss = model(ids_masked, tts, mask)
            mlm_loss = mlm_loss_fct(mlm_logits.view(-1, mlm_logits.size(-1)), mlm_labels.view(-1))

            # Rows that carry a real NSP label.
            is_chunk = torch.tensor([bt == "C" for bt in btypes], dtype=torch.bool, device=device)
            has_chunks = bool(is_chunk.any())
            if has_chunks:
                chunk_logits = nsp_logits.view(-1, 2)[is_chunk]
                chunk_labels = nsp_labels.view(-1)[is_chunk]
                nsp_loss = nsp_loss_fct(chunk_logits, chunk_labels)
            else:
                nsp_loss = torch.tensor(0.0, device=device)

            loss = mlm_loss + nsp_loss + aux_coeff * aux_loss
            total_loss += loss.item()
            total_steps += 1

            # MLM accuracy
            mlm_preds = mlm_logits.argmax(-1)
            mask_positions = mlm_labels != -100
            n_masked = int(mask_positions.sum().item())
            if n_masked > 0:
                total_mlm_correct += (mlm_preds[mask_positions] == mlm_labels[mask_positions]).sum().item()
                total_mlm_count += n_masked

            # NSP accuracy
            if has_chunks:
                total_nsp_correct += (chunk_logits.argmax(-1) == chunk_labels).sum().item()
                total_nsp_count += chunk_labels.numel()

    avg_loss = total_loss / max(1, total_steps)
    mlm_acc = total_mlm_correct / total_mlm_count if total_mlm_count > 0 else 0.0
    nsp_acc = total_nsp_correct / total_nsp_count if total_nsp_count > 0 else 0.0

    return {"avg_loss": avg_loss, "mlm_acc": mlm_acc, "nsp_acc": nsp_acc}

# -------------------------
# Helper: build data, vocab, embeddings given hyperparams
# -------------------------
def prepare_data_and_model_artifacts(corpus: List[Tuple[str, str]], hyperparams: Dict[str, Any] = None):
    """
    Build the vocabulary, Word2Vec-initialized embeddings and the pre-training
    dataset for ``corpus``. No train/val/test splitting happens here.

    ``hyperparams`` may override ``word2vec_window``, ``word2vec_size``,
    ``word2vec_min_count`` and ``max_seq_len``; missing keys fall back to the
    module defaults.

    Returns: stoi, itos, vocab_size, emb (tensor), pad_id, mask_id, ds (dataset)
    """
    settings = hyperparams if hyperparams else {}
    w2v_window = int(settings.get("word2vec_window", WORD2VEC_WINDOW))
    w2v_size = int(settings.get("word2vec_size", WORD2VEC_SIZE))
    w2v_min_count = int(settings.get("word2vec_min_count", WORD2VEC_MIN_COUNT))
    seq_limit = int(settings.get("max_seq_len", MAX_SEQ_LEN))

    # Vocabulary and Word2Vec both see the raw text of every document.
    raw_texts = [entry[0] for entry in corpus]
    stoi, itos = build_vocab(raw_texts, min_freq=VOCAB_MIN_FREQ)
    vocab_size = len(itos)

    word_vectors = train_word2vec(
        raw_texts,
        vector_size=w2v_size,
        window=w2v_window,
        min_count=w2v_min_count,
        epochs=5,
    )
    # The embedding table is always HIDDEN_SIZE wide, whatever the Word2Vec size.
    emb = build_embedding_matrix(word_vectors, itos, HIDDEN_SIZE)

    pad_id = stoi[PAD_TOKEN]
    mask_id = stoi[MASK_TOKEN]

    ds = BertPretrainingDataset(corpus, stoi, max_seq_len=seq_limit)
    print(f"[DEBUG] Corpus size: {len(corpus)} documents")
    print(f"[DEBUG] Dataset size (after pair generation): {len(ds)} samples")
    print(f"[DEBUG] Vocab size: {vocab_size}")
    return stoi, itos, vocab_size, emb, pad_id, mask_id, ds

# -------------------------
# BO helpers: search space sampling, GP + EI
# -------------------------
def sample_random_hyperparams():
    """
    Draw one random configuration from the BO search space.

    Uses NumPy's global RNG; the draws happen in the order the keys are
    listed, so results are reproducible under ``np.random.seed``.
    """
    rng = np.random

    # learning_rate is log-uniform in [1e-6, 1e-5] (upper bound kept low to
    # prevent instability).
    log_lr = rng.uniform(np.log10(1e-6), np.log10(1e-5))

    return {
        "learning_rate": float(10 ** log_lr),
        # moe_experts ∈ {3,4,5,6}
        "moe_experts": int(rng.choice([3, 4, 5, 6])),
        # moe_k ∈ {1,2}
        "moe_k": int(rng.choice([1, 2])),
        # ffn_dim ∈ [1024,4096] in multiples of 256
        "ffn_dim": int(rng.choice(np.arange(1024, 4097, 256))),
        # num_layers ∈ {6,8,12}
        "num_layers": int(rng.choice([6, 8, 12])),
        # num_heads ∈ {6,8,12}
        "num_heads": int(rng.choice([6, 8, 12])),
        # word2vec_window ∈ {3,5,7}
        "word2vec_window": int(rng.choice([3, 5, 7])),
        # mlm_mask_prob uniform in [0.10,0.20]
        "mlm_mask_prob": float(rng.uniform(0.10, 0.20)),
        # batch_size choices
        "batch_size": int(rng.choice([4, 8, 16])),
    }

def hyperparams_to_vector(hp: Dict[str,Any]) -> np.ndarray:
    """
    Encode a hyperparameter dict as the numeric feature vector fed to the GP.

    Feature order:
    [log10(lr), moe_experts, moe_k, ffn_dim/256, num_layers, num_heads,
     word2vec_window, mlm_mask_prob, batch_size]

    ``batch_size`` falls back to BATCH_SIZE when absent; every other key is
    required.
    """
    features = (
        np.log10(hp["learning_rate"]),
        float(hp["moe_experts"]),
        float(hp["moe_k"]),
        float(hp["ffn_dim"]) / 256.0,  # expressed in units of 256
        float(hp["num_layers"]),
        float(hp["num_heads"]),
        float(hp["word2vec_window"]),
        float(hp["mlm_mask_prob"]),
        float(hp.get("batch_size", BATCH_SIZE)),
    )
    return np.array(features, dtype=float)

def expected_improvement(mu: np.ndarray, sigma: np.ndarray, best: float, xi: float = 0.01):
    """
    Expected Improvement for minimization.

    Computes EI = E[max(best - f - xi, 0)] for f ~ N(mu, sigma^2).

    Args:
        mu, sigma: predictive mean and standard deviation per candidate,
            arrays of shape (n_candidates,). Scalars are accepted as well.
        best: current best (lowest) observed objective value.
        xi: exploration margin subtracted from the improvement.

    Returns:
        Array of EI values with the broadcast shape of ``mu`` and ``sigma``.

    ``sigma`` is clamped to at least 1e-9 to avoid division by zero. With that
    clamp a zero-variance candidate evaluates to ~max(best - mu - xi, 0), so
    no separate zero-sigma branch is needed.
    """
    mu = np.asarray(mu, dtype=float)
    sigma = np.maximum(np.asarray(sigma, dtype=float), 1e-9)
    imp = best - mu - xi
    Z = imp / sigma
    ei = imp * norm.cdf(Z) + sigma * norm.pdf(Z)
    return np.asarray(ei, dtype=float)

# -------------------------
# Objective function that runs K-fold CV on combined train+val dataset
# -------------------------
def objective_function_with_kfold(hp: Dict[str,Any],
                                  vocab_size: int,
                                  emb: torch.Tensor,
                                  pad_id: int,
                                  mask_id: int,
                                  combined_indices: List[int],
                                  ds: Dataset,
                                  k_folds: int,
                                  run_id: int,
                                  verbose: bool = True) -> Tuple[float, List[Dict[str, Any]]]:
    """
    Given hyperparameters and the combined indices (train+val indices referring to original dataset ds),
    perform k-fold CV on that combined set. For each fold:
      - Create model from scratch
      - Train for TRIAL_EPOCHS on training folds
      - Evaluate on validation fold (compute avg total loss)

    combined_indices: list of indices into ds to be used for CV
    ds: original dataset

    Returns:
        (avg_val_loss, fold_details): the mean validation loss across folds
        and a list of per-fold dicts (fold, val_loss, train_size, val_size).
        Non-finite losses are replaced by the penalty value 1e9. If fewer
        than two folds can be formed, ``(1e9, [])`` is returned, because
        KFold needs at least 2 splits.
    """
    global MLM_MASK_PROB

    # Adjust k_folds if dataset is too small
    n_samples = len(combined_indices)
    actual_k_folds = min(k_folds, n_samples)
    if actual_k_folds < k_folds:
        if verbose:
            print(f"[Run {run_id}] Warning: Reducing k_folds from {k_folds} to {actual_k_folds} due to small dataset size ({n_samples} samples)")
    if actual_k_folds < 2:
        # KFold rejects n_splits < 2; report the same penalty used for failed folds.
        if verbose:
            print(f"[Run {run_id}] Warning: K-fold CV needs at least 2 folds (got {actual_k_folds}); returning penalty loss")
        return float(1e9), []

    # Set up KFold (shuffle for randomness but fixed random_state for reproducibility)
    kf = KFold(n_splits=actual_k_folds, shuffle=True, random_state=42)

    # Hyperparameters are constant across folds; parse them once.
    batch_size = int(hp.get("batch_size", BATCH_SIZE))
    lr = float(hp["learning_rate"])
    aux_coeff = 0.01

    def _collate(batch):
        return collate_fn(batch, pad_id)

    fold_losses = []
    fold_details = []

    # For each fold, create train and val Subset objects
    combined_array = np.array(combined_indices)
    fold_pbar = tqdm(enumerate(kf.split(combined_array)), total=actual_k_folds, desc=f"[Run {run_id}] K-Fold CV", leave=False)
    for fold_idx, (train_pos, val_pos) in fold_pbar:
        # Map positions to original dataset indices
        train_subset = Subset(ds, combined_array[train_pos].tolist())
        val_subset = Subset(ds, combined_array[val_pos].tolist())

        dl_train = DataLoader(train_subset, batch_size=batch_size, shuffle=True, collate_fn=_collate)
        dl_val = DataLoader(val_subset, batch_size=batch_size, shuffle=False, collate_fn=_collate)

        # Build model afresh with hyperparameters
        model = BertEncoderModel(
            vocab_size=vocab_size,
            hidden_size=HIDDEN_SIZE,
            num_layers=int(hp["num_layers"]),
            num_heads=int(hp["num_heads"]),
            ffn_dim=int(hp["ffn_dim"]),
            max_position_embeddings=MAX_SEQ_LEN,
            pad_token_id=pad_id,
            embedding_weights=emb,
            moe_experts=int(hp["moe_experts"]),
            moe_k=int(hp["moe_k"])
        ).to(DEVICE)

        mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        nsp_loss_fct = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

        # Override the global MLM mask prob for the duration of this fold.
        # NOTE(review): this only reaches the masking helper if that helper
        # looks MLM_MASK_PROB up at call time rather than capturing it as a
        # default-argument value - verify.
        prev_mlm_mask_prob = MLM_MASK_PROB
        MLM_MASK_PROB = float(hp.get("mlm_mask_prob", MLM_MASK_PROB))

        try:
            epoch_pbar = tqdm(range(TRIAL_EPOCHS), desc=f"Fold {fold_idx+1}/{actual_k_folds} Epochs", leave=False)
            for epoch in epoch_pbar:
                _train_loss = train_one_epoch(model, dl_train, optimizer, mlm_loss_fct, nsp_loss_fct,
                                              pad_id, mask_id, vocab_size, DEVICE, aux_coeff)
                epoch_pbar.set_postfix({"train_loss": f"{_train_loss:.6f}"})
                if verbose:
                    print(f"[Run {run_id}] Fold {fold_idx+1}/{actual_k_folds} Epoch {epoch+1}/{TRIAL_EPOCHS} - train_loss: {_train_loss:.6f}")

            # Evaluate on the fold's validation data
            val_loss = evaluate_model(model, dl_val, pad_id, mask_id, vocab_size, DEVICE)
            # Validate loss is finite
            if not np.isfinite(val_loss):
                if verbose:
                    print(f"[Run {run_id}] Warning: Fold {fold_idx+1} returned non-finite loss, using large penalty value")
                val_loss = 1e9
            fold_losses.append(val_loss)
            fold_details.append({"fold": fold_idx+1, "val_loss": float(val_loss), "train_size": len(train_subset), "val_size": len(val_subset)})
            fold_pbar.set_postfix({"val_loss": f"{val_loss:.6f}"})
            if verbose:
                print(f"[Run {run_id}] Fold {fold_idx+1}/{actual_k_folds} - val_loss: {val_loss:.6f}")
        finally:
            MLM_MASK_PROB = prev_mlm_mask_prob
            # Drop the references first so empty_cache() can actually release
            # this fold's parameters and optimizer state.
            del model, optimizer
            torch.cuda.empty_cache()

    # average val loss over folds
    avg_val_loss = float(np.mean(fold_losses)) if fold_losses else float(1e9)
    # Final check for NaN
    if not np.isfinite(avg_val_loss):
        avg_val_loss = 1e9
    if verbose:
        print(f"[Run {run_id}] K-Fold CV average val loss (k={actual_k_folds}): {avg_val_loss:.6f}")
    return avg_val_loss, fold_details

# -------------------------
# Main: Bayesian Optimization loop (does a single dataset split before BO)
# -------------------------
def _bo_split_lengths(total_len: int) -> Tuple[int, int, int]:
    """Return (train_len, val_len, test_len) sizes that sum exactly to ``total_len``.

    Datasets with fewer than 5 samples get one test sample and, when at least
    one training sample would remain, one validation sample. Larger datasets
    use roughly 80/10/10 (each of val/test is ``max(1, total_len // 10)``).

    Raises:
        ValueError: if ``total_len`` is below 2, because a non-empty train set
            and a non-empty test set cannot both be carved out.
    """
    if total_len < 2:
        raise ValueError(
            f"Need at least 2 samples for a train/test split, got {total_len}"
        )
    if total_len < 5:
        test_len = 1
        val_len = 1 if total_len - test_len - 1 > 0 else 0
    else:
        test_len = max(1, total_len // 10)  # ~10%
        val_len = max(1, total_len // 10)   # ~10%
    return total_len - val_len - test_len, val_len, test_len


def _bo_write_json(path: str, payload: Any, **dump_kwargs: Any) -> None:
    """Serialize ``payload`` as JSON to ``path``, overwriting any existing file."""
    with open(path, "w") as f:
        json.dump(payload, f, **dump_kwargs)


def _bo_sample_hyperparams() -> Dict[str, Any]:
    """Draw one random configuration and fill in the fixed fields kept for logging."""
    hp = sample_random_hyperparams()
    hp["batch_size"] = int(hp.get("batch_size", BATCH_SIZE))
    hp["word2vec_size"] = WORD2VEC_SIZE
    hp["word2vec_min_count"] = WORD2VEC_MIN_COUNT
    return hp


def run_bayesian_optimization_with_heldout_test_cv5(corpus: List[Tuple[str,str]]) -> Dict[str, Any]:
    """
    Run the full hyperparameter search and final evaluation pipeline.

    - Prepare artifacts (vocab, emb, dataset)
    - Split once into train/val/test (test held-out and never seen by BO)
    - Run BO using train+val combined and K-fold CV on that combined set:
      BO_INIT_POINTS random configurations first, then GP + Expected
      Improvement proposals until BO_ITERATIONS evaluations have been made
    - Retrain final model on train+val for TRIAL_EPOCHS epochs
    - Evaluate final model on test and save metrics + model

    Every evaluated configuration (random and GP-proposed) is recorded in
    ``bo_records`` and written to ``<BO_LOG_DIR>/bo_results.json`` after each
    evaluation; the best configuration is chosen among all of them.

    Args:
        corpus: list of (text, dtype) pairs handed to
            ``prepare_data_and_model_artifacts``.

    Returns:
        Dict with keys ``"best_record"``, ``"test_metrics"`` and ``"bo_records"``.

    Raises:
        ValueError: if the prepared dataset has fewer than 2 samples.
    """
    # Declared up front: the final-training section temporarily overrides it.
    global MLM_MASK_PROB

    # Prepare artifacts (vocab, embeddings, dataset)
    stoi, itos, vocab_size, emb, pad_id, mask_id, ds = prepare_data_and_model_artifacts(corpus)

    # Split into Train/Val/Test once (~80/10/10, with a fallback for tiny datasets)
    total_len = len(ds)
    train_len, val_len, test_len = _bo_split_lengths(total_len)
    train_ds, val_ds, test_ds = random_split(ds, [train_len, val_len, test_len])

    print(f"[DATA SPLIT] total={total_len}, train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")

    # random_split returns Subset objects whose .indices refer to the original
    # dataset; train+val together form the pool BO is allowed to see.
    trainval_indices = list(train_ds.indices) + list(val_ds.indices)
    # Shuffle a copy for K-fold splitting; the unshuffled list is reused for the final retrain.
    combined_indices = list(trainval_indices)
    random.shuffle(combined_indices)

    print(f"[BO] Using combined train+val pool of size {len(combined_indices)} for {K_FOLDS}-fold CV in BO.")

    # BO bookkeeping: X holds hyperparameter vectors, y the matching CV losses.
    bo_records = []
    X = []
    y = []
    bo_results_path = os.path.join(BO_LOG_DIR, "bo_results.json")

    # initial random evaluations
    print("[BO] Starting initial random evaluations (random seed sampling)...")
    init_pbar = tqdm(range(BO_INIT_POINTS), desc="BO Initial Samples")
    for i in init_pbar:
        hp = _bo_sample_hyperparams()
        print(f"[BO] Init sample {i+1}/{BO_INIT_POINTS}: {hp}")
        val_loss, fold_details = objective_function_with_kfold(hp, vocab_size, emb, pad_id, mask_id, combined_indices, ds, K_FOLDS, run_id=i+1, verbose=True)
        X.append(hyperparams_to_vector(hp))
        y.append(val_loss)
        bo_records.append({"iteration": i+1, "params": hp, "loss": float(val_loss), "folds": fold_details})
        init_pbar.set_postfix({"val_loss": f"{val_loss:.6f}"})
        # save intermediate logs
        _bo_write_json(bo_results_path, bo_records, indent=2)

    # GP surrogate (one length scale per hyperparameter dimension)
    kernel = C(1.0, (1e-3, 1e3)) * Matern(length_scale=np.ones(X[0].shape[0]), nu=2.5) + WhiteKernel(noise_level=1e-6)
    gp = GaussianProcessRegressor(kernel=kernel, alpha=1e-6, normalize_y=True, n_restarts_optimizer=3, random_state=42)

    # main BO iterations
    n_candidates = 200
    bo_pbar = tqdm(range(BO_INIT_POINTS, BO_ITERATIONS), desc="BO Iterations")
    for it in bo_pbar:
        print(f"[BO] Iteration {it+1}/{BO_ITERATIONS}")

        # fit GP on current data
        X_arr = np.vstack(X)
        y_arr = np.array(y)
        gp.fit(X_arr, y_arr)

        # propose many random candidates and pick one that maximizes EI
        candidates = [_bo_sample_hyperparams() for _ in range(n_candidates)]
        cand_vecs = np.vstack([hyperparams_to_vector(c) for c in candidates])

        mu, sigma = gp.predict(cand_vecs, return_std=True)
        best_so_far = np.min(y_arr)
        eis = expected_improvement(mu, sigma, best_so_far, xi=0.01)
        chosen_hp = candidates[int(np.argmax(eis))]
        print(f"[BO] Chosen hyperparams (EI max): {chosen_hp}")

        # evaluate chosen configuration via K-fold CV
        val_loss, fold_details = objective_function_with_kfold(chosen_hp, vocab_size, emb, pad_id, mask_id, combined_indices, ds, K_FOLDS, run_id=it+1, verbose=True)
        X.append(hyperparams_to_vector(chosen_hp))
        y.append(val_loss)

        # Record the GP-guided evaluation too; otherwise it is neither logged
        # nor eligible to be selected as the best configuration below.
        bo_records.append({"iteration": it+1, "params": chosen_hp, "loss": float(val_loss), "folds": fold_details})
        bo_pbar.set_postfix({"val_loss": f"{val_loss:.6f}", "best": f"{np.min(y):.6f}"})
        # save logs
        _bo_write_json(bo_results_path, bo_records, indent=2)

    # compute best config over every evaluated configuration
    losses = [r["loss"] for r in bo_records]
    best_record = bo_records[int(np.argmin(losses))]
    best_hp = best_record["params"]
    best_loss = best_record["loss"]
    print(f"[BO] Best config found: {best_hp} with loss {best_loss:.6f}")

    # save best config to file
    _bo_write_json(os.path.join(BO_LOG_DIR, "best_config.json"), {"best_params": best_hp, "best_loss": best_loss}, indent=2)

    # Retrain final model on train+val with best hyperparams
    print("[BO] Training final model on train+val with best hyperparameters...")
    final_batch_size = int(best_hp.get("batch_size", BATCH_SIZE))
    dl_full = DataLoader(Subset(ds, trainval_indices), batch_size=final_batch_size, shuffle=True, collate_fn=lambda b: collate_fn(b, pad_id))

    final_model = BertEncoderModel(
        vocab_size=vocab_size,
        hidden_size=HIDDEN_SIZE,
        num_layers=int(best_hp["num_layers"]),
        num_heads=int(best_hp["num_heads"]),
        ffn_dim=int(best_hp["ffn_dim"]),
        max_position_embeddings=MAX_SEQ_LEN,
        pad_token_id=pad_id,
        embedding_weights=emb,
        moe_experts=int(best_hp["moe_experts"]),
        moe_k=int(best_hp["moe_k"])
    ).to(DEVICE)

    # optimizer and losses
    lr = float(best_hp["learning_rate"])
    optimizer = torch.optim.AdamW(final_model.parameters(), lr=lr)
    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    nsp_loss_fct = nn.CrossEntropyLoss()
    aux_coeff = 0.01

    # Temporarily set the global MLM mask probability to the chosen value;
    # the finally-clause restores it even if training fails.
    prev_mlm_mask_prob = MLM_MASK_PROB
    MLM_MASK_PROB = float(best_hp.get("mlm_mask_prob", MLM_MASK_PROB))
    try:
        final_epoch_pbar = tqdm(range(TRIAL_EPOCHS), desc="Final Training")
        for epoch in final_epoch_pbar:
            t_loss = train_one_epoch(final_model, dl_full, optimizer, mlm_loss_fct, nsp_loss_fct, pad_id, mask_id, vocab_size, DEVICE, aux_coeff)
            final_epoch_pbar.set_postfix({"loss": f"{t_loss:.6f}"})
            print(f"[Final Train] Epoch {epoch+1}/{TRIAL_EPOCHS} - loss {t_loss:.6f}")
    finally:
        MLM_MASK_PROB = prev_mlm_mask_prob

    # Evaluate final model on the held-out test set (never used in BO)
    test_dl = DataLoader(test_ds, batch_size=final_batch_size, shuffle=False, collate_fn=lambda b: collate_fn(b, pad_id))
    test_metrics = compute_metrics(final_model, test_dl, pad_id, mask_id, vocab_size, DEVICE)
    print("[Final Evaluation on TEST set]")
    print(json.dumps(test_metrics, indent=2))

    # Save final model and vocab into bo_logs/best_model/
    save_dir = os.path.join(BO_LOG_DIR, "best_model")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(final_model.state_dict(), os.path.join(save_dir, "bert_encoder.pt"))
    _bo_write_json(os.path.join(save_dir, "vocab.json"), {"stoi": stoi, "itos": itos})
    # Save test metrics and BO records
    _bo_write_json(os.path.join(save_dir, "test_metrics.json"), test_metrics, indent=2)
    _bo_write_json(bo_results_path, bo_records, indent=2)
    print(f"[BO] Best model, vocab and test metrics saved to {save_dir}")
    return {"best_record": best_record, "test_metrics": test_metrics, "bo_records": bo_records}


# -------------------------
# Load corpus from ChromaDB
# -------------------------
def load_corpus_from_chromadb(db_path: str = "../VectorDB/chroma_Data", collection_name: str = "harry_potter_collection"):
    """
    Load corpus from ChromaDB collection.

    Each stored document is tagged from its ``ischunk`` metadata flag:
    ``True`` -> 'C' (chunk), ``False`` -> 'Q' (query). Documents whose
    metadata is missing entirely or lacks the flag are logged with a warning
    and treated as chunks.

    Args:
        db_path: directory passed to ``chromadb.PersistentClient``.
        collection_name: name of the collection to read.

    Returns:
        List of tuples: [(text, dtype), ...] where dtype is 'Q' for query or
        'C' for chunk. If anything fails while reading from ChromaDB, the
        error is printed and the result of ``get_default_corpus()`` is
        returned instead.
    """
    print(f"[INFO] Loading corpus from ChromaDB at: {db_path}")
    print(f"[INFO] Collection name: {collection_name}")

    try:
        client_db = chromadb.PersistentClient(path=db_path)
        collection = client_db.get_collection(name=collection_name)

        # Get all documents with metadata
        results = collection.get(include=["documents", "metadatas"])

        corpus = []
        chunk_count = 0
        query_count = 0

        for doc, meta in zip(results["documents"], results["metadatas"]):
            # A record stored without metadata comes back as None; treat it
            # like a record without the flag instead of letting the
            # AttributeError discard the entire corpus via the fallback below.
            ischunk = meta.get("ischunk") if meta else None
            if ischunk is True:
                corpus.append((doc, "C"))
                chunk_count += 1
            elif ischunk is False:
                corpus.append((doc, "Q"))
                query_count += 1
            else:
                # No usable 'ischunk' flag: assume it's a chunk
                print(f"[WARNING] Document without 'ischunk' metadata: {doc[:50]}...")
                corpus.append((doc, "C"))  # default to chunk
                chunk_count += 1

        print(f"[INFO] Loaded {len(corpus)} documents: {chunk_count} chunks, {query_count} queries")
        return corpus

    except Exception as e:
        # Deliberate best-effort: any failure falls back to the default corpus.
        print(f"[ERROR] Failed to load from ChromaDB: {e}")
        print("[INFO] Falling back to default corpus")
        return get_default_corpus()

# -------------------------
# Entrypoint
# -------------------------
def main():
    """Load the corpus, run the BO pipeline, and print the best record and test metrics."""
    # Comes from ChromaDB, or from the default corpus when loading fails.
    corpus = load_corpus_from_chromadb()

    # Heuristic: a corpus of exactly four items is taken to be the default one.
    looks_like_default = len(corpus) == 4
    if looks_like_default:
        print("[INFO] Using small default corpus for testing")

    print("[MAIN] Starting Bayesian Optimization over hyperparameters with 5-fold CV per BO eval...")
    results = run_bayesian_optimization_with_heldout_test_cv5(corpus)

    # Print each summary section as a heading followed by pretty-printed JSON.
    summary_sections = (
        ("[MAIN] BO Completed. Summary:", "best_record"),
        ("Test metrics:", "test_metrics"),
    )
    for heading, key in summary_sections:
        print(heading)
        print(json.dumps(results[key], indent=2))


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
    

[MAIN] Starting Bayesian Optimization over hyperparameters with 5-fold CV per BO eval...
[DATA SPLIT] total=4, train=2, val=1, test=1
[BO] Using combined train+val pool of size 3 for 5-fold CV in BO.
[BO] Starting initial random evaluations (random seed sampling)...


BO Initial Samples:   0%|          | 0/3 [00:00<?, ?it/s]

[BO] Init sample 1/3: {'learning_rate': 4.523556587070803e-06, 'moe_experts': 6, 'moe_k': 2, 'ffn_dim': 3584, 'num_layers': 12, 'num_heads': 12, 'word2vec_window': 3, 'mlm_mask_prob': 0.18062383981151098, 'batch_size': 4, 'word2vec_size': 768, 'word2vec_min_count': 1}


[Run 1] K-Fold CV:   0%|          | 0/3 [00:00<?, ?it/s][A
[A

[A[A
[A

[A[A

[A[A
[A

[A[A
BO Initial Samples:   0%|          | 0/3 [00:05<?, ?it/s]



RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!