In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import csv
import chromadb
import math
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Tuple
from dataclasses import dataclass
import os

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", DEVICE)

# ============================================================
# 1. LOAD YOUR PRETRAINED BERT ENCODER MODEL WITH MoE
#    (This must match your previously saved architecture)
# ============================================================

# Special-token strings; their integer ids are looked up in the vocabulary's
# ``stoi`` mapping (see ``tokenize`` / ``collate``).
PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN = (
    "[PAD]",
    "[CLS]",
    "[SEP]",
    "[MASK]",
    "[UNK]",
)

SPECIAL_TOKENS = [PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN]

# ------------------------
# Mixture of Experts (MoE)
# ------------------------
class MoE(nn.Module):
    """Top-k routed Mixture-of-Experts feed-forward block.

    A linear router scores every token against ``num_experts`` two-layer GELU
    MLPs. Each token is processed only by its ``k`` highest-scoring experts, and
    their outputs are mixed with a softmax over the selected router logits.
    During training, Gaussian noise (``noise_std``) is added to the logits
    before the top-k selection.
    """

    def __init__(self, hidden_size, ffn_dim, num_experts=5, k=2, noise_std=1.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.num_experts = num_experts
        self.k = k
        self.noise_std = noise_std

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, ffn_dim),
                nn.GELU(),
                nn.Linear(ffn_dim, hidden_size)
            ) for _ in range(num_experts)
        ])

        self.router = nn.Linear(hidden_size, num_experts)

    def forward(self, x, mask=None):
        """Route every token of ``x`` through its top-k experts.

        Args:
            x: tensor of shape (B, S, H).
            mask: accepted for interface compatibility and not used. Padding
                tokens are routed, and counted in the auxiliary loss, like any
                other token.

        Returns:
            ``(out, aux_loss)``. ``out`` has the shape of ``x``. ``aux_loss`` is
            the scalar load-balancing term
            ``num_experts * sum_e(mean routing probability of expert e) ** 2``.
        """
        B, S, H = x.size()
        logits = self.router(x)  # (B, S, E)

        # Load-balancing loss from the noise-free routing distribution.
        probs_all = F.softmax(logits, dim=-1)
        importance = probs_all.sum(dim=(0, 1))
        total_tokens = float(B * S)
        aux_loss = self.num_experts * (importance / total_tokens).pow(2).sum()

        if self.training:
            logits_noisy = logits + torch.randn_like(logits) * self.noise_std
        else:
            logits_noisy = logits

        topk_vals, topk_idx = torch.topk(logits_noisy, self.k, dim=-1)
        topk_weights = F.softmax(topk_vals, dim=-1)

        # Sparse dispatch: run each expert only on the tokens routed to it. The
        # previous dense version evaluated all experts on all tokens and zeroed
        # most results, which costs num_experts/k times the activation memory
        # for the same output.
        flat_x = x.reshape(-1, H)
        flat_idx = topk_idx.reshape(-1, self.k)
        flat_w = topk_weights.reshape(-1, self.k).to(x.dtype)

        out = torch.zeros_like(flat_x)
        for e, expert in enumerate(self.experts):
            # topk indices are distinct per token, so a token hits expert e at
            # most once; ``slot`` says which of its k choices that was.
            token_pos, slot = (flat_idx == e).nonzero(as_tuple=True)
            if token_pos.numel() == 0:
                continue
            weighted = expert(flat_x[token_pos]) * flat_w[token_pos, slot].unsqueeze(-1)
            out = out.index_add(0, token_pos, weighted)

        return out.view(B, S, H), aux_loss

# ------------------------
# Transformer Layer with MoE
# ------------------------
class TransformerEncoderLayer(nn.Module):
    """Post-LayerNorm transformer encoder layer whose feed-forward is an MoE.

    ``forward(x, mask)`` applies self-attention, then residual + ``ln1``, then
    the MoE feed-forward, then residual + ``ln2``. It returns
    ``(x, aux_loss)``, where ``aux_loss`` comes from the MoE router.
    ``mask`` has non-zero entries for real tokens. Positions where it is 0 are
    excluded as attention keys.
    """

    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1, moe_experts=5, moe_k=2):
        super().__init__()
        # Attribute names and creation order are unchanged: they define the
        # state_dict keys of the saved checkpoint.
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads,
                                               dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.ffn_moe = MoE(hidden_size, ffn_dim, num_experts=moe_experts, k=moe_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        key_padding_mask = (mask == 0)
        # need_weights=False: the attention map is discarded anyway. Requesting
        # it makes PyTorch materialize the full (B, S, S) weights instead of
        # using the fused scaled-dot-product-attention kernel.
        attn_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask,
                                     need_weights=False)
        x = self.ln1(x + self.dropout(attn_out))
        ffn_out, aux_loss = self.ffn_moe(x, mask)
        x = self.ln2(x + self.dropout(ffn_out))
        return x, aux_loss

# ------------------------
# Base BERT Encoder with MoE
# ------------------------
class BertEncoderModel(nn.Module):
    """BERT-style encoder whose feed-forward sublayers are Mixture-of-Experts.

    Submodule attribute names match the saved checkpoint and must not change:
    ``token_embeddings``, ``position_embeddings``, ``segment_embeddings``,
    ``emb_ln``, ``layers``, ``nsp_classifier`` and ``mlm_bias``.
    """

    def __init__(self, vocab_size, hidden_size=768,
                 num_layers=12, num_heads=12, ffn_dim=3072,
                 max_position_embeddings=512, pad_token_id=0,
                 moe_experts=5, moe_k=2, embedding_weights=None):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size

        # Token table; optionally seeded from a pre-computed weight matrix.
        self.token_embeddings = nn.Embedding(
            vocab_size, hidden_size, padding_idx=pad_token_id
        )
        if embedding_weights is not None:
            self.token_embeddings.weight.data.copy_(embedding_weights)

        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.segment_embeddings = nn.Embedding(2, hidden_size)

        self.emb_ln = nn.LayerNorm(hidden_size)
        self.emb_dropout = nn.Dropout(0.1)

        layer_kwargs = {"dropout": 0.1, "moe_experts": moe_experts, "moe_k": moe_k}
        stack = [
            TransformerEncoderLayer(hidden_size, num_heads, ffn_dim, **layer_kwargs)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

        # Pre-training heads: present in the checkpoint, unused by encode().
        self.nsp_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 2),
        )
        self.mlm_bias = nn.Parameter(torch.zeros(vocab_size))

    def encode(self, ids, tt=None, mask=None):
        """Return the final-layer hidden state at position 0 ([CLS]) for each row.

        ``tt`` (segment ids) defaults to all zeros. ``mask`` defaults to
        ``ids != pad_token_id``. The per-layer MoE auxiliary losses are summed
        but not returned.
        """
        token_types = torch.zeros_like(ids) if tt is None else tt
        attention_mask = (ids != self.pad_token_id).long() if mask is None else mask

        positions = torch.arange(ids.size(1), device=ids.device).unsqueeze(0)
        embedded = (
            self.token_embeddings(ids)
            + self.position_embeddings(positions)
            + self.segment_embeddings(token_types)
        )
        hidden = self.emb_dropout(self.emb_ln(embedded))

        aux_sum = 0.0
        for block in self.layers:
            hidden, block_aux = block(hidden, attention_mask)
            aux_sum = aux_sum + block_aux

        cls_vectors = hidden[:, 0]
        return cls_vectors

# ============================================================
# 2. APPLY LoRA TO THE BERT ATTENTION MODULES
# ============================================================

class LoRALinear(nn.Module):
    """Low-rank adapter around a frozen linear layer.

    Computes ``base(x) + dropout(lora_B(lora_A(x))) * (alpha / r)``. ``lora_B``
    starts at zero, so a freshly built wrapper reproduces ``base`` exactly.
    Construction freezes ``base`` and leaves only ``lora_A``/``lora_B`` trainable.
    """

    def __init__(self, base_layer, r=8, alpha=8, dropout=0.0):
        super().__init__()
        self.base = base_layer
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        fan_in = base_layer.in_features
        fan_out = base_layer.out_features
        self.lora_A = nn.Linear(fan_in, r, bias=False)
        self.lora_B = nn.Linear(r, fan_out, bias=False)

        # A: nn.Linear-style Kaiming init; B: zeros, so the update starts at 0.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        self.dropout = nn.Dropout(dropout)

        self.train_lora_only()

    def train_lora_only(self):
        """Freeze every base parameter and unfreeze the two LoRA projections."""
        schedule = ((self.base, False), (self.lora_A, True), (self.lora_B, True))
        for module, trainable in schedule:
            for param in module.parameters():
                param.requires_grad = trainable

    def forward(self, x):
        low_rank = self.lora_B(self.lora_A(x))
        update = self.dropout(low_rank) * self.scaling
        return self.base(x) + update

class _QKVLoRA(nn.Module):
    """LoRA parametrization for the packed Q/K/V weight of ``nn.MultiheadAttention``.

    ``nn.MultiheadAttention`` stores the query, key and value projections stacked
    row-wise in one ``(3 * embed_dim, embed_dim)`` ``in_proj_weight``. Once
    registered with ``torch.nn.utils.parametrize``, this module is called with
    the frozen original weight every time ``in_proj_weight`` is read. It returns
    ``W + (alpha / r) * [B_q A_q; B_k A_k; B_v A_v]``, so the attention forward
    pass really uses the adapters.
    """

    def __init__(self, embed_dim, r=8, alpha=8, device=None, dtype=None):
        super().__init__()
        if r <= 0:
            raise ValueError(f"LoRA rank r must be positive, got {r}")
        self.embed_dim = embed_dim
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        # One (A, B) pair per projection, stacked on dim 0 in q, k, v order.
        # A gets nn.Linear-style init; B starts at zero so the adapted weight
        # equals the pretrained one until training moves it.
        lora_A = torch.empty(3, r, embed_dim, device=device, dtype=dtype)
        for proj in range(3):
            nn.init.kaiming_uniform_(lora_A[proj], a=math.sqrt(5))
        self.lora_A = nn.Parameter(lora_A)
        self.lora_B = nn.Parameter(torch.zeros(3, embed_dim, r, device=device, dtype=dtype))

    def forward(self, weight):
        # (3, E, r) @ (3, r, E) -> (3, E, E); the reshape stacks the q, k, v
        # deltas row-wise, matching the layout of in_proj_weight.
        delta = torch.bmm(self.lora_B, self.lora_A)
        delta = delta.reshape(3 * self.embed_dim, self.embed_dim)
        return weight + delta * self.scaling


def apply_lora(model, r=8, alpha=8, freeze_base=True):
    """Attach trainable LoRA adapters to the Q, K and V projections of every layer.

    ``nn.MultiheadAttention`` with equal q/k/v dims computes its projections
    only from the packed ``in_proj_weight``. Extra ``q_proj``-style attributes
    attached to it are never called. The adapters are therefore injected as a
    parametrization of ``in_proj_weight`` itself.

    Args:
        model: module with ``model.layers``, each exposing ``self_attn`` (an
            ``nn.MultiheadAttention`` whose q/k/v share ``embed_dim``).
        r: LoRA rank; must be > 0.
        alpha: LoRA scaling numerator; the update is scaled by ``alpha / r``.
        freeze_base: if True (default), every parameter already in ``model`` is
            frozen first, so only the new adapters are trainable.

    Returns:
        The same ``model``, modified in place.

    Raises:
        ValueError: if ``r <= 0``.
        NotImplementedError: if an attention module has no packed ``in_proj_weight``.
    """
    from torch.nn.utils import parametrize

    if r <= 0:
        raise ValueError(f"LoRA rank r must be positive, got {r}")

    if freeze_base:
        for param in model.parameters():
            param.requires_grad_(False)

    for layer in model.layers:
        attn = layer.self_attn
        weight = attn.in_proj_weight
        if weight is None:
            raise NotImplementedError(
                "apply_lora needs a packed in_proj_weight (kdim == vdim == embed_dim)"
            )
        adapter = _QKVLoRA(attn.embed_dim, r=r, alpha=alpha,
                           device=weight.device, dtype=weight.dtype)
        # The original weight keeps its requires_grad flag (frozen above) and
        # moves to attn.parametrizations.in_proj_weight.original.
        parametrize.register_parametrization(attn, "in_proj_weight", adapter)

    return model

# ============================================================
# 3. CHROMADB LOADING
# ============================================================
# Open the persistent Chroma store when this cell/module runs. ContrastiveDataset
# reads chunk ids and documents from the module-level `collection`.
# NOTE(review): the relative path resolves against the current working
# directory, so this presumably has to run from the notebook's folder — confirm.
client = chromadb.PersistentClient(path="../VectorDB/chroma_Data_with_BERT_embeddings")
collection = client.get_collection("HP_Chunks_BERT_Embeddings_collection")

# ============================================================
# 4. TRAINING DATASET (CSV WITH query, chunk_ID)
# ============================================================

@dataclass
class QueryChunkPair:
    query: str
    positive_id: str

class ContrastiveDataset(Dataset):
    def __init__(self, csv_path):
        self.pairs: List[QueryChunkPair] = []
        with open(csv_path, "r") as f:
            reader = csv.reader(f)
            next(reader)
            for row in reader:
                # Expecting: query, positive_id
                if len(row) < 2:
                    continue
                self.pairs.append(QueryChunkPair(query=row[0], positive_id=row[1]))

        # Query chroma for chunk entries. Adjust the where-clause if your metadata differs.
        all_chunks = collection.get(where={"ischunk": True})
        # Defensive: ensure "ids" key exists and is a list
        ids = all_chunks.get("ids") if isinstance(all_chunks, dict) else None
        if not ids:
            # No chunk ids found â€” raise a clear error with suggestions
            raise RuntimeError(
                "No chunk ids found in Chroma collection for where={'ischunk': True}. "
                "Check collection path and metadata keys. "
                "Got response keys: {}. Response summary: {}".format(
                    list(all_chunks.keys()) if isinstance(all_chunks, dict) else str(type(all_chunks)),
                    {k: (len(all_chunks[k]) if isinstance(all_chunks.get(k), list) else type(all_chunks.get(k)))
                     for k in (all_chunks.keys() if isinstance(all_chunks, dict) else [])}
                )
            )

        self.chunk_ids = list(ids)

        # Optionally create a set for fast membership tests
        self.chunk_id_set = set(self.chunk_ids)

        # Basic sanity: confirm positive IDs from CSV exist (warn only)
        missing_positives = [p.positive_id for p in self.pairs if p.positive_id not in self.chunk_id_set]
        if missing_positives:
            # Show a short sample to not flood logs
            sample_missing = missing_positives[:10]
            print(f"[Warning] {len(missing_positives)} positive IDs from CSV not found in Chroma collection. Examples: {sample_missing}")
            # You may choose to raise here if this is critical:
            # raise RuntimeError("Positive IDs mismatch between CSV and Chroma collection.")

    def __len__(self):
        return len(self.pairs)

    def sample_negatives(self, positive_id: str, k=5):
        """
        Robust negative sampling:
        - Build candidate list excluding the positive_id
        - If there are >= k distinct candidates: use random.sample (no replacement)
        - If there are between 1 and k-1 candidates: sample with replacement from candidates
        - If there are 0 candidates: raise informative error
        """
        # Build candidate pool excluding the positive
        if not self.chunk_ids:
            raise RuntimeError("No chunk ids available in dataset (self.chunk_ids is empty). Check Chroma collection population.")
        candidates = [cid for cid in self.chunk_ids if cid != positive_id]

        if len(candidates) == 0:
            raise RuntimeError(
                f"No negative candidates available for positive_id={positive_id}. "
                "Either the collection only contains that one id or positive_id doesn't exist in collection."
            )

        if len(candidates) >= k:
            return random.sample(candidates, k)
        else:
            # sample with replacement to meet required k
            return [random.choice(candidates) for _ in range(k)]

    def __getitem__(self, idx):
        item = self.pairs[idx]

        # Validate positive id exists
        if item.positive_id not in self.chunk_id_set:
            # Return a helpful error rather than failing deep inside
            raise RuntimeError(f"Positive id '{item.positive_id}' from CSV not found in Chroma collection.")

        pos_chunk = collection.get(ids=[item.positive_id])
        pos_documents = pos_chunk.get("documents", [])
        if not pos_documents:
            raise RuntimeError(f"No document found in Chroma for positive id {item.positive_id} (collection.get returned empty).")

        neg_ids = self.sample_negatives(item.positive_id, k=5)
        neg_chunks = collection.get(ids=neg_ids)
        neg_documents = neg_chunks.get("documents", [])

        # If some negatives are missing, fallback to empty strings (or raise if you prefer)
        if len(neg_documents) < len(neg_ids):
            # fill missing with empty strings to keep shapes consistent
            filled = []
            for i, cid in enumerate(neg_ids):
                try:
                    doc = neg_documents[i]
                except Exception:
                    doc = ""
                filled.append(doc)
            neg_documents = filled

        return {
            "query": item.query,
            "positive_text": pos_documents[0],
            "negative_texts": neg_documents
        }

# ============================================================
# 5. TOKENIZER (simple whitespace tokenizer)
# ============================================================

def tokenize(text, stoi, max_len=64):
    """Whitespace-tokenize ``text`` into unpadded BERT-style inputs.

    Words missing from ``stoi`` map to the [UNK] id. At most ``max_len - 2``
    word ids are kept and wrapped as ``[CLS] ... [SEP]``.

    Returns:
        ``(ids, token_type_ids, attention_mask)``: 1-D ``torch.long`` tensors of
        equal length. Token types are all 0 and the mask is all 1.
    """
    words = text.strip().split()
    body = [stoi.get(word, stoi[UNK_TOKEN]) for word in words][: max_len - 2]
    ids = [stoi[CLS_TOKEN], *body, stoi[SEP_TOKEN]]
    length = len(ids)
    ids_tensor = torch.tensor(ids, dtype=torch.long)
    segment_tensor = torch.zeros(length, dtype=torch.long)
    mask_tensor = torch.ones(length, dtype=torch.long)
    return ids_tensor, segment_tensor, mask_tensor

def collate(batch, max_len=128):
    """Tokenize and pad a list of ``ContrastiveDataset`` items into tensors.

    Queries, positives and negatives are padded separately, each to the longest
    sequence in its own group (whitespace words + [CLS]/[SEP], capped at
    ``max_len``). Previously every group was padded to the longest negative
    chunk, so short queries paid full chunk length in memory and compute.
    Padding positions get mask 0 and are excluded as attention keys, so this
    does not change the [CLS] embeddings.

    Reads the module-level ``stoi`` vocabulary (set by ``train_lora``).

    Returns:
        dict of ``torch.long`` tensors:
        ``queries``/``queries_tt``/``queries_mask`` of shape (B, Lq);
        ``pos``/``pos_tt``/``pos_mask`` of shape (B, Lp);
        ``neg``/``neg_tt``/``neg_mask`` of shape (B, K, Ln).

    Raises:
        ValueError: if ``batch`` is empty or its items have different numbers
            of negatives.
    """
    if not batch:
        raise ValueError("collate received an empty batch")
    neg_counts = {len(b["negative_texts"]) for b in batch}
    if len(neg_counts) != 1:
        raise ValueError(
            f"every item must have the same number of negatives, got counts {sorted(neg_counts)}"
        )
    num_negs = len(batch[0]["negative_texts"])
    pad_id = stoi[PAD_TOKEN]

    def _fit(tensor, length, value):
        # Truncate to, or right-pad with ``value`` up to, exactly ``length`` items.
        if tensor.size(0) >= length:
            return tensor[:length]
        return torch.cat([tensor, tensor.new_full((length - tensor.size(0),), value)])

    def _encode_group(texts):
        # Tokenize ``texts`` and stack them as (N, L) ids / segment / mask tensors.
        seq_len = min(max(len(t.split()) + 2 for t in texts), max_len)
        ids, tts, masks = [], [], []
        for text in texts:
            t_ids, t_tt, t_mask = tokenize(text, stoi, seq_len)
            ids.append(_fit(t_ids, seq_len, pad_id))
            tts.append(_fit(t_tt, seq_len, 0))
            masks.append(_fit(t_mask, seq_len, 0))
        return torch.stack(ids), torch.stack(tts), torch.stack(masks)

    q_ids, q_tt, q_mask = _encode_group([b["query"] for b in batch])
    p_ids, p_tt, p_mask = _encode_group([b["positive_text"] for b in batch])
    # All negatives of the batch form one group, regrouped per item afterwards.
    n_ids, n_tt, n_mask = _encode_group([t for b in batch for t in b["negative_texts"]])
    neg_shape = (len(batch), num_negs, n_ids.size(1))

    return {
        "queries": q_ids,
        "queries_tt": q_tt,
        "queries_mask": q_mask,

        "pos": p_ids,
        "pos_tt": p_tt,
        "pos_mask": p_mask,

        "neg": n_ids.view(neg_shape),
        "neg_tt": n_tt.view(neg_shape),
        "neg_mask": n_mask.view(neg_shape),
    }

# ============================================================
# 6. CONTRASTIVE LOSS (softmax over exp(cos()))
# ============================================================

def contrastive_loss(q, pos, negs, temperature=1.0):
    """InfoNCE loss of each query against its positive and K negatives.

    Args:
        q: (B, H) query embeddings.
        pos: (B, H) positive embeddings, row-aligned with ``q``.
        negs: (B, K, H) negative embeddings per query (any K >= 0).
        temperature: cosine similarities are divided by this before the
            softmax. The default 1.0 reproduces the original loss.

    Returns:
        Scalar mean over the batch of ``-log softmax(sims / temperature)[:, 0]``,
        where ``sims = [cos(q, pos), cos(q, neg_1), ..., cos(q, neg_K)]``.

    Raises:
        ValueError: if ``temperature <= 0``.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    pos_sim = F.cosine_similarity(q, pos, dim=-1)                 # (B,)
    # Broadcast q over the K negatives instead of materializing a repeated copy.
    neg_sim = F.cosine_similarity(q.unsqueeze(1), negs, dim=-1)   # (B, K)

    logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) / temperature  # (B, 1 + K)
    # The positive is always column 0. cross_entropy fuses log-softmax + NLL,
    # which is the numerically stable form of -log(exp(s0) / sum(exp(s))).
    target = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, target)

# ============================================================
# 7. MAIN TRAINING LOOP
# ============================================================

def train_lora(
    model_path="../Encoder/saved_bert_encoder_moe_pooling/bert_encoder_moe_pooling.pt",
    vocab_path="../Encoder/saved_bert_encoder_moe_pooling/vocab.json",
    csv_path="../LLM Caller/generated_pairs_without_commas.csv",
    batch_size=1,
    lr=1e-4,
    epochs=5,
):
    """Fine-tune LoRA adapters on the pretrained MoE encoder with a contrastive loss.

    Steps:
      1. Load the vocabulary and publish ``stoi``/``itos`` as module globals
         (``collate`` reads ``stoi``).
      2. Rebuild ``BertEncoderModel`` and load ``model_path``, then ``apply_lora``.
      3. Optimize, with AdamW, every parameter that still has
         ``requires_grad=True``, on (query, positive, negatives) batches from
         ``csv_path``.

    The final ``state_dict`` is written to ``lora_finetuned/lora_bert.pt``. Its
    keys include the LoRA modules, so reloading it requires building the model
    and calling ``apply_lora`` before ``load_state_dict``.
    """
    import json

    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab = json.load(f)

    global stoi, itos
    stoi, itos = vocab["stoi"], vocab["itos"]
    vocab_size = len(itos)

    # Create model with MoE architecture matching the saved model
    model = BertEncoderModel(vocab_size, max_position_embeddings=512,
                             moe_experts=5, moe_k=2)
    # Load onto the CPU. The model still lives on the CPU here, so mapping
    # straight to DEVICE would only allocate a throwaway GPU copy of every weight.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))

    apply_lora(model)
    model.to(DEVICE)

    trainable = [p for p in model.parameters() if p.requires_grad]
    if not trainable:
        raise RuntimeError("No trainable parameters after apply_lora; nothing to fine-tune.")
    n_trainable = sum(p.numel() for p in trainable)
    n_total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {n_trainable:,} / {n_total:,}")

    ds = ContrastiveDataset(csv_path)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, collate_fn=collate)

    opt = torch.optim.AdamW(trainable, lr=lr)

    model.train()
    for epoch in range(epochs):
        for batch in dl:
            q_emb = model.encode(
                batch["queries"].to(DEVICE),
                batch["queries_tt"].to(DEVICE),
                batch["queries_mask"].to(DEVICE),
            )
            p_emb = model.encode(
                batch["pos"].to(DEVICE),
                batch["pos_tt"].to(DEVICE),
                batch["pos_mask"].to(DEVICE),
            )

            n = batch["neg"].to(DEVICE)
            B, K, L = n.size()
            # Encode all B*K negatives in one pass. Splitting them into
            # micro-batches does not lower peak memory, because every
            # micro-batch's activations stay alive until loss.backward().
            n_emb = model.encode(
                n.view(B * K, L),
                batch["neg_tt"].to(DEVICE).view(B * K, L),
                batch["neg_mask"].to(DEVICE).view(B * K, L),
            ).view(B, K, -1)

            loss = contrastive_loss(q_emb, p_emb, n_emb)

            opt.zero_grad()
            loss.backward()
            opt.step()

            print(f"Epoch {epoch} Loss {loss.item():.4f}")

    os.makedirs("lora_finetuned", exist_ok=True)
    torch.save(model.state_dict(), "lora_finetuned/lora_bert.pt")
    print("LoRA fine-tuned model saved.")

# ============================================================
# RUN TRAINING
# ============================================================

if __name__ == "__main__":
    train_lora()

Using device: cuda


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 10.75 GiB of which 14.62 MiB is free. Including non-PyTorch memory, this process has 7.66 GiB memory in use. Process 1559818 has 3.07 GiB memory in use. Of the allocated memory 6.83 GiB is allocated by PyTorch, and 691.59 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)