# In [None]:  (Jupyter cell marker left over from the notebook export)
import csv
import math
import os
import random
from dataclasses import dataclass
from typing import List, Dict, Tuple

import chromadb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)

# ============================================================
# 1. LOAD YOUR PRETRAINED BERT ENCODER MODEL
#    (This must match your previously saved architecture)
# ============================================================

# Reserved vocabulary symbols. The individual names below are unpacked from
# the list, so both always agree on spelling and order.
SPECIAL_TOKENS = ["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"]
PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN = SPECIAL_TOKENS

# ------------------------
# Transformer Layer
# ------------------------
class TransformerEncoderLayer(nn.Module):
    """Post-LayerNorm encoder block: self-attention, then a GELU feed-forward net.

    The output of each sub-layer goes through dropout, is added back to the
    sub-layer's input, and the sum is layer-normalised.
    """

    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        # Attribute names and creation order are deliberately untouched: they
        # determine the state_dict keys that saved checkpoints are loaded into.
        self.self_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        expand = nn.Linear(hidden_size, ffn_dim)
        contract = nn.Linear(ffn_dim, hidden_size)
        self.ffn = nn.Sequential(expand, nn.GELU(), contract)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the block on ``x``; positions where ``mask`` is 0 are ignored as keys."""
        # The attention module skips keys flagged True, so turn the
        # "non-zero = keep" mask into a "True = skip" mask.
        skip_keys = mask.eq(0)
        attended = self.self_attn(x, x, x, key_padding_mask=skip_keys)[0]
        hidden = self.ln1(x + self.dropout(attended))
        return self.ln2(hidden + self.dropout(self.ffn(hidden)))

# ------------------------
# Base BERT Encoder
# ------------------------
class BertEncoderModel(nn.Module):
    """BERT-style encoder that returns the hidden state of the first position.

    Token, position and segment embeddings are summed, layer-normalised,
    passed through dropout and then through a stack of
    ``TransformerEncoderLayer`` blocks.
    """

    def __init__(self, vocab_size, hidden_size=768,
                 num_layers=12, num_heads=12, ffn_dim=3072,
                 max_position_embeddings=512, pad_token_id=0,
                 embedding_weights=None):
        super().__init__()
        self.pad_token_id = pad_token_id

        self.token_embeddings = nn.Embedding(
            vocab_size, hidden_size, padding_idx=pad_token_id
        )
        if embedding_weights is not None:
            # Optional warm start from a caller-supplied embedding matrix.
            self.token_embeddings.weight.data.copy_(embedding_weights)

        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        # Two rows: one per segment / token-type id (0 or 1).
        self.segment_embeddings = nn.Embedding(2, hidden_size)

        self.emb_ln = nn.LayerNorm(hidden_size)
        self.emb_dropout = nn.Dropout(0.1)

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerEncoderLayer(hidden_size, num_heads, ffn_dim))
        self.layers = nn.ModuleList(blocks)

    def encode(self, ids, tt=None, mask=None):
        """Encode a batch of token ids and return the first-position vectors.

        Args:
            ids: integer token ids, indexed as (batch, sequence).
            tt: segment ids shaped like ``ids``; all zeros when omitted.
            mask: attention mask shaped like ``ids`` (0 = ignore); derived
                from ``ids != pad_token_id`` when omitted.

        Returns:
            The final hidden state at sequence position 0 for every batch row.
        """
        segments = torch.zeros_like(ids) if tt is None else tt
        if mask is None:
            attention_mask = (ids != self.pad_token_id).long()
        else:
            attention_mask = mask

        seq_len = ids.size(1)
        # Shape (1, seq_len) so the position embeddings broadcast over the batch.
        positions = torch.arange(seq_len, device=ids.device).unsqueeze(0)

        embedded = self.token_embeddings(ids) + self.position_embeddings(positions)
        embedded = embedded + self.segment_embeddings(segments)

        hidden = self.emb_dropout(self.emb_ln(embedded))
        for block in self.layers:
            hidden = block(hidden, attention_mask)

        return hidden[:, 0]   # CLS embedding

# ============================================================
# 2. APPLY LoRA TO THE BERT ATTENTION MODULES
# ============================================================

class LoRALinear(nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) update.

    The output is ``base(x) + dropout(B(A(x))) * (alpha / r)``, where ``A``
    projects down to rank ``r`` and ``B`` projects back up. ``B`` starts at
    zero, so a freshly wrapped layer reproduces ``base`` exactly.

    Args:
        base_layer: the linear layer to adapt; its parameters are frozen.
        r: rank of the update; must be positive.
        alpha: numerator of the scaling factor ``alpha / r``.
        dropout: dropout probability applied to the low-rank branch output.

    Raises:
        ValueError: if ``r`` is not positive.
    """

    def __init__(self, base_layer, r=8, alpha=8, dropout=0.0):
        super().__init__()
        if r <= 0:
            # r == 0 would divide by zero below; a negative rank is meaningless.
            raise ValueError(f"LoRA rank r must be positive, got {r}")

        self.base = base_layer
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        # Create the adapters on the base layer's device and dtype so wrapping
        # also works after the model has been moved or cast.
        like_base = {
            "device": base_layer.weight.device,
            "dtype": base_layer.weight.dtype,
        }
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False, **like_base)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False, **like_base)

        # A gets a random init, B is zero, so the initial update B @ A is zero.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        self.dropout = nn.Dropout(dropout)

        self.train_lora_only()

    def train_lora_only(self):
        """Freeze base layer, train LoRA params only."""
        for p in self.base.parameters():
            p.requires_grad = False
        for p in self.lora_A.parameters():
            p.requires_grad = True
        for p in self.lora_B.parameters():
            p.requires_grad = True

    def forward(self, x):
        """Return the frozen base projection plus the scaled low-rank update."""
        return self.base(x) + self.dropout(self.lora_B(self.lora_A(x))) * self.scaling

class _LoRASelfAttention(nn.Module):
    """Multi-head attention whose Q/K/V projections are ``LoRALinear`` layers.

    Built from an existing batch-first ``nn.MultiheadAttention``: the packed
    input projection is split into three linear layers that start from the
    pretrained weights, each wrapped with LoRA, and the original output
    projection is reused. The call signature mirrors how the encoder layer
    invokes attention: ``module(q, k, v, key_padding_mask=...)`` returns
    ``(output, weights)``.
    """

    def __init__(self, attn, r, alpha):
        super().__init__()
        if attn.in_proj_weight is None:
            raise ValueError("expected an attention module with a packed in_proj_weight")
        if not getattr(attn, "batch_first", False):
            raise ValueError("expected a batch_first attention module")

        hidden = attn.embed_dim
        self.embed_dim = hidden
        self.num_heads = attn.num_heads
        self.head_dim = hidden // attn.num_heads
        self.attn_dropout = attn.dropout
        self.out_proj = attn.out_proj

        # The packed projection stacks the Q, K and V matrices along dim 0.
        weights = attn.in_proj_weight.detach().chunk(3, dim=0)
        if attn.in_proj_bias is None:
            biases = (None, None, None)
        else:
            biases = attn.in_proj_bias.detach().chunk(3, dim=0)

        for name, weight, bias in zip(("q_proj", "k_proj", "v_proj"), weights, biases):
            proj = nn.Linear(hidden, hidden, bias=bias is not None)
            with torch.no_grad():
                proj.weight.copy_(weight)
                if bias is not None:
                    proj.bias.copy_(bias)
            wrapped = LoRALinear(proj, r=r, alpha=alpha)
            # Keep base and adapters on the device/dtype of the pretrained weights.
            setattr(self, name, wrapped.to(device=weight.device, dtype=weight.dtype))

    def _split_heads(self, t):
        """Reshape (batch, seq, embed) into (batch, heads, seq, head_dim)."""
        batch, length, _ = t.shape
        return t.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True):
        """Scaled dot-product attention; True entries in ``key_padding_mask`` are ignored."""
        batch, tgt_len, _ = query.shape

        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key))
        v = self._split_heads(self.v_proj(value))

        scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        if key_padding_mask is not None:
            # Broadcast the (batch, src) mask over heads and query positions.
            blocked = key_padding_mask.bool()[:, None, None, :]
            scores = scores.masked_fill(blocked, float("-inf"))

        weights = F.dropout(scores.softmax(dim=-1), p=self.attn_dropout,
                            training=self.training)
        context = torch.matmul(weights, v).transpose(1, 2)
        context = context.reshape(batch, tgt_len, self.embed_dim)
        output = self.out_proj(context)
        return output, (weights.mean(dim=1) if need_weights else None)


def apply_lora(model, r=8, alpha=8, freeze_base=False):
    """Replace every layer's self-attention with a LoRA-adapted equivalent.

    ``nn.MultiheadAttention`` with equal query/key/value sizes computes its
    projections from the packed ``in_proj_weight`` only, so attaching
    ``q_proj``-style attributes to it has no effect on the forward pass.
    Each ``layer.self_attn`` is therefore swapped for a module that projects
    through ``LoRALinear`` wrappers seeded with the pretrained Q/K/V weights.
    Because the LoRA up-projection starts at zero, outputs are unchanged until
    training updates the adapters.

    Note that the attention state_dict keys change (for example
    ``self_attn.q_proj.base.weight`` instead of ``self_attn.in_proj_weight``),
    so call this after loading the pretrained checkpoint.

    Args:
        model: object exposing ``layers``, each with a ``self_attn`` module.
        r: LoRA rank.
        alpha: LoRA scaling numerator.
        freeze_base: when True, every parameter already in ``model`` is frozen
            first so that only the adapters train. The default leaves all
            parameters other than the Q/K/V projections as they were.

    Returns:
        The same ``model`` instance, modified in place.
    """
    if freeze_base:
        for p in model.parameters():
            p.requires_grad = False

    for layer in model.layers:
        attn = layer.self_attn
        if isinstance(attn, _LoRASelfAttention):
            # Already adapted; wrapping twice would stack adapters.
            continue
        layer.self_attn = _LoRASelfAttention(attn, r=r, alpha=alpha)

    return model

# ============================================================
# 3. CHROMADB LOADING
# ============================================================
# Open the on-disk Chroma store under "chroma_db/" and look up the collection
# named "rag_collection". Both lines run at import time; `collection` is the
# module-level handle that ContrastiveDataset reads chunk ids and texts from.
# NOTE(review): get_collection presumably fails when the collection does not
# exist yet - confirm against the chromadb version in use.
client = chromadb.PersistentClient(path="chroma_db/")
collection = client.get_collection("rag_collection")

# ============================================================
# 4. TRAINING DATASET (CSV WITH query, chunk_ID)
# ============================================================

@dataclass
class QueryChunkPair:
    """One training pair: a query and the id of its positive chunk."""
    # Query text (first CSV column).
    query: str
    # Id of the positive chunk in the Chroma collection (second CSV column).
    positive_id: str

class ContrastiveDataset(Dataset):
    """Dataset of query / positive-chunk / negative-chunk text triples.

    Pairs come from a CSV file whose first row is a header and whose first
    two columns are the query text and the positive chunk id. Chunk texts
    are fetched from the module-level Chroma ``collection`` on access, and
    negatives are drawn at random from the ids of all records whose
    ``isChunk`` metadata is true.
    """

    def __init__(self, csv_path):
        self.pairs: List[QueryChunkPair] = []
        # newline="" is the csv module's documented way to open files so that
        # quoted fields containing line breaks are parsed correctly.
        with open(csv_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header; tolerate an empty file
            for row in reader:
                if len(row) < 2:
                    # Blank or truncated line: nothing usable to pair up.
                    continue
                self.pairs.append(QueryChunkPair(query=row[0], positive_id=row[1]))

        all_chunks = collection.get(
            where={"isChunk": True}
        )
        self.chunk_ids = all_chunks["ids"]

    def __len__(self):
        return len(self.pairs)

    def sample_negatives(self, positive_id: str, k=5):
        """Return ``k`` chunk ids drawn at random, none equal to ``positive_id``.

        Ids are distinct whenever at least ``k`` other ids exist. With fewer
        candidates the available ids are repeated to reach ``k``.

        Raises:
            ValueError: if ``k`` is positive and no id other than
                ``positive_id`` is available (this case used to loop forever).
        """
        # Drawing k + 1 distinct ids guarantees k remain after dropping the
        # positive, without scanning the whole id list on every call.
        draw = random.sample(self.chunk_ids, min(k + 1, len(self.chunk_ids)))
        negs = [cid for cid in draw if cid != positive_id][:k]
        if len(negs) < k:
            if not negs:
                raise ValueError(
                    f"no negative candidates available for positive id {positive_id!r}"
                )
            negs += random.choices(negs, k=k - len(negs))
        return negs

    def __getitem__(self, idx):
        """Return the query, its positive chunk text and five negative chunk texts."""
        item = self.pairs[idx]

        pos_chunk = collection.get(ids=[item.positive_id])
        if not pos_chunk["documents"]:
            # KeyError rather than the IndexError a bare [0] would raise:
            # an IndexError from __getitem__ can be mistaken for end-of-data.
            raise KeyError(
                f"positive chunk {item.positive_id!r} not found in collection"
            )

        neg_ids = self.sample_negatives(item.positive_id, k=5)
        # Query each id once and re-align by id afterwards, so the result does
        # not depend on the store's return order or on how it treats
        # duplicate ids in a request.
        neg_chunks = collection.get(ids=list(dict.fromkeys(neg_ids)))
        text_by_id = dict(zip(neg_chunks["ids"], neg_chunks["documents"]))

        return {
            "query": item.query,
            "positive_text": pos_chunk["documents"][0],
            "negative_texts": [text_by_id[cid] for cid in neg_ids]
        }

# ============================================================
# 5. TOKENIZER (simple whitespace tokenizer)
# ============================================================

def tokenize(text, stoi, max_len=64):
    """Whitespace-tokenise ``text`` into id, segment and mask tensors.

    The text is split on whitespace, cut to at most ``max_len - 2`` words,
    mapped through ``stoi`` (unknown words become the UNK id) and wrapped in
    CLS ... SEP. No padding is added, so the tensor length varies with the
    input.

    Args:
        text: raw input string.
        stoi: token-to-id mapping containing the CLS, SEP and UNK tokens.
        max_len: maximum output length including CLS and SEP. Values below 2
            still produce the two special tokens.

    Returns:
        ``(ids, token_type_ids, attention_mask)`` as 1-D long tensors of equal
        length; the segment ids are all 0 and the mask is all 1.
    """
    # Clamp at zero: a negative bound would slice words off the END of the
    # text instead of truncating it.
    budget = max(max_len - 2, 0)
    # Truncate before the lookups so long texts are not mapped needlessly.
    words = text.split()[:budget]

    unk_id = stoi[UNK_TOKEN]
    ids = [stoi[CLS_TOKEN]] + [stoi.get(w, unk_id) for w in words] + [stoi[SEP_TOKEN]]
    length = len(ids)
    return (
        torch.tensor(ids, dtype=torch.long),
        torch.zeros(length, dtype=torch.long),
        torch.ones(length, dtype=torch.long)
    )

def collate(batch, max_seq_len=512):
    """Tokenise and pad a list of dataset items into batched tensors.

    Every query, positive and negative text is tokenised with the module-level
    ``stoi`` vocabulary and right-padded to one common length, so that the
    sequences can be stacked. That length is the longest text in the batch
    (negatives included), capped at ``max_seq_len``.

    Args:
        batch: list of dicts with keys ``"query"``, ``"positive_text"`` and
            ``"negative_texts"``; every item must carry the same number of
            negatives.
        max_seq_len: upper bound on the padded length, CLS and SEP included.
            It should not exceed the encoder's position-embedding table size.

    Returns:
        Dict of long tensors: ``queries``/``pos`` and their ``_tt``/``_mask``
        companions are (batch, length); ``neg``, ``neg_tt`` and ``neg_mask``
        are (batch, negatives, length). Padded positions hold the PAD id and
        a mask value of 0.

    Raises:
        ValueError: if ``max_seq_len`` is smaller than 2.
    """
    if max_seq_len < 2:
        raise ValueError("max_seq_len must leave room for the CLS and SEP tokens")

    def token_count(text):
        # Whitespace words plus the CLS and SEP tokens added by tokenize.
        return len(text.split()) + 2

    longest = 0
    for b in batch:
        longest = max(longest, token_count(b["query"]), token_count(b["positive_text"]))
        for nt in b["negative_texts"]:
            longest = max(longest, token_count(nt))
    max_len = min(longest, max_seq_len)

    pad_id = stoi.get(PAD_TOKEN, 0)

    def encode(text):
        """Tokenise ``text`` and right-pad all three tensors to ``max_len``."""
        ids, tt, mask = tokenize(text, stoi, max_len)
        gap = max_len - ids.size(0)
        if gap > 0:
            ids = F.pad(ids, (0, gap), value=pad_id)
            tt = F.pad(tt, (0, gap), value=0)
            mask = F.pad(mask, (0, gap), value=0)
        return ids, tt, mask

    def stack_fields(triples):
        """Turn a list of (ids, tt, mask) triples into three stacked tensors."""
        ids, tts, masks = zip(*triples)
        return torch.stack(ids), torch.stack(tts), torch.stack(masks)

    query_parts, pos_parts, neg_parts = [], [], []
    for b in batch:
        query_parts.append(encode(b["query"]))
        pos_parts.append(encode(b["positive_text"]))
        # One (negatives, length) triple per item.
        neg_parts.append(stack_fields([encode(nt) for nt in b["negative_texts"]]))

    q_ids, q_tt, q_mask = stack_fields(query_parts)
    p_ids, p_tt, p_mask = stack_fields(pos_parts)
    n_ids, n_tt, n_mask = stack_fields(neg_parts)

    return {
        "queries": q_ids,
        "queries_tt": q_tt,
        "queries_mask": q_mask,

        "pos": p_ids,
        "pos_tt": p_tt,
        "pos_mask": p_mask,

        "neg": n_ids,
        "neg_tt": n_tt,
        "neg_mask": n_mask,
    }

# ============================================================
# 6. CONTRASTIVE LOSS (softmax cross-entropy over cosine similarities)
# ============================================================

def contrastive_loss(q, pos, negs, temperature=1.0):
    """Softmax cross-entropy over cosine similarities, positive as the target.

    For each query, the cosine similarity to its positive and to each of its
    negatives forms one row of logits; the loss is the mean negative
    log-probability assigned to the positive.

    Args:
        q: (B, H) query embeddings.
        pos: (B, H) positive embeddings.
        negs: (B, K, H) negative embeddings.
        temperature: divisor applied to the similarities before the softmax.
            The default of 1.0 leaves them unscaled; smaller values sharpen
            the distribution.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    pos_sim = F.cosine_similarity(q, pos, dim=-1)      # (B,)
    # expand_as creates a broadcast view, avoiding the copy .repeat would make.
    neg_sim = F.cosine_similarity(q.unsqueeze(1).expand_as(negs), negs, dim=-1)  # (B, K)

    logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) / temperature  # (B, 1 + K)

    # log_softmax computes log(exp(x) / sum(exp(x))) in a numerically stable
    # way; column 0 is the positive.
    return -F.log_softmax(logits, dim=1)[:, 0].mean()

# ============================================================
# 7. MAIN TRAINING LOOP
# ============================================================

def train_lora(
    model_path="saved_bert_encoder/bert_encoder.pt",
    vocab_path="saved_bert_encoder/vocab.json",
    csv_path="positive_pairs.csv",
    batch_size=4,
    lr=1e-4,
    epochs=3,
    output_dir="lora_finetuned",
):
    """Fine-tune the saved encoder on query/chunk pairs with a contrastive loss.

    Loads the vocabulary and pretrained weights, applies ``apply_lora``,
    trains every parameter that still requires gradients with AdamW, prints
    the loss after each step and saves the final state_dict.

    Args:
        model_path: checkpoint holding the encoder's state_dict.
        vocab_path: JSON file with ``"stoi"`` and ``"itos"`` entries.
        csv_path: CSV of (query, positive chunk id) rows with a header line.
        batch_size: number of pairs per step.
        lr: AdamW learning rate.
        epochs: number of passes over the dataset.
        output_dir: directory that receives ``lora_bert.pt``; created if
            missing.

    Side effects:
        Sets the module-level ``stoi`` and ``itos`` globals, which ``collate``
        reads when the DataLoader builds batches.
    """
    import json

    global stoi, itos

    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab = json.load(f)

    stoi, itos = vocab["stoi"], vocab["itos"]
    vocab_size = len(itos)

    model = BertEncoderModel(vocab_size)
    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))

    # Adapters are attached after the pretrained weights are in place and
    # before the move to DEVICE, so they are moved together with the model.
    apply_lora(model)

    model.to(DEVICE)

    ds = ContrastiveDataset(csv_path)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, collate_fn=collate)

    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=lr)

    model.train()
    for epoch in range(epochs):
        for batch in dl:
            batch = {key: tensor.to(DEVICE) for key, tensor in batch.items()}

            q_emb = model.encode(batch["queries"], batch["queries_tt"], batch["queries_mask"])
            p_emb = model.encode(batch["pos"], batch["pos_tt"], batch["pos_mask"])

            # Fold the negatives axis into the batch axis for a single encoder
            # pass, then restore it on the resulting embeddings.
            B, K, L = batch["neg"].size()
            n_emb = model.encode(
                batch["neg"].reshape(B * K, L),
                batch["neg_tt"].reshape(B * K, L),
                batch["neg_mask"].reshape(B * K, L),
            ).view(B, K, -1)

            loss = contrastive_loss(q_emb, p_emb, n_emb)

            opt.zero_grad()
            loss.backward()
            opt.step()

            print(f"Epoch {epoch} Loss {loss.item():.4f}")

    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "lora_bert.pt"))
    print("LoRA fine-tuned model saved.")

# ============================================================
# RUN TRAINING
# ============================================================

# Entry point: when executed directly (not imported), run fine-tuning with
# the default paths and hyperparameters of train_lora.
if __name__ == "__main__":
    train_lora()